Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forecast by asset type #8

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ COPY --from=builder /venv /venv
COPY --from=builder /app/dist .
RUN . /venv/bin/activate && pip install *.whl

ENTRYPOINT ["app"]
ENTRYPOINT ["app", "--write-to-db"]
44 changes: 26 additions & 18 deletions india_forecast_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from pvsite_datamodel import DatabaseConnection
from pvsite_datamodel.read import get_sites_by_country
from pvsite_datamodel.sqlmodels import SiteSQL
from pvsite_datamodel.write import insert_forecast_values
from sqlalchemy.orm import Session

Expand All @@ -19,31 +20,33 @@
log = logging.getLogger(__name__)


def get_site_ids(db_session: Session) -> list[str]:
def get_sites(db_session: Session) -> list[SiteSQL]:
"""
Gets all avaiable site_ids in India
Gets all available sites in India

Args:
db_session: A SQLAlchemy session

Returns:
A list of site_ids
A list of SiteSQL objects
"""

sites = get_sites_by_country(db_session, country="india")
return sites

return [s.site_uuid for s in sites]


def get_model():
def get_model(asset_type: str):
"""
Instantiates and returns the forecast model ready for running inference

Args:
asset_type: One or "pv" or "wind"

Returns:
A forecasting model
"""

model = DummyModel()
model = DummyModel(asset_type)
return model


Expand Down Expand Up @@ -151,17 +154,22 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str):

# 1. Get sites
log.info("Getting sites...")
site_ids = get_site_ids(session)
log.info(f"Found {len(site_ids)} sites")

# 2. Load model
log.info("Loading model...")
model = get_model()
log.info("Loaded model")

# 3. Run model for each site
for site_id in site_ids:
log.info(f"Running model for site={site_id}...")
sites = get_sites(session)
log.info(f"Found {len(sites)} sites")

# 2. Load models
log.info("Loading models...")
pv_model = get_model("pv")
log.info("Loaded PV model")
wind_model = get_model("wind")
log.info("Loaded wind model")

for site in sites:
# 3. Run model for each site
site_id = site.site_uuid
asset_type = site.asset_type.name
log.info(f"Running {asset_type} model for site={site_id}...")
model = wind_model if asset_type == "wind" else pv_model
forecast_values = run_model(model=model, site_id=site_id, timestamp=timestamp)

if forecast_values is None:
Expand Down
12 changes: 7 additions & 5 deletions india_forecast_app/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ class DummyModel:
"""
Dummy model that emulates the capabilities expected by a real model
"""

@property
def version(self):
"""Version number"""
return "0.0.0"

def __init__(self):
def __init__(self, asset_type: str):
"""Initializer for the model"""
pass
self.asset_type = asset_type

def predict(self, site_id: str, timestamp: dt.datetime):
"""Make a prediction for the model"""
Expand All @@ -35,7 +35,9 @@ def _generate_dummy_forecast(self, timestamp: dt.datetime):

for i in range(numSteps):
time = start + i * step
_yield = _basicSolarYieldFunc(int(time.timestamp()))
gen_func = _basicSolarYieldFunc if self.asset_type == "pv" else _basicWindYieldFunc
_yield = gen_func(int(time.timestamp()))

values.append(
{
"start_utc": time,
Expand Down Expand Up @@ -99,6 +101,6 @@ def _basicSolarYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float:

def _basicWindYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float:
"""Gets a fake wind yield for the input time."""
output = min(scaleFactor, scaleFactor * 10 * random.random())
output = min(scaleFactor, scaleFactor * random.random())

return output
26 changes: 11 additions & 15 deletions scripts/seed_local_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from pvsite_datamodel.connection import DatabaseConnection
from pvsite_datamodel.sqlmodels import Base
from pvsite_datamodel.write.user_and_site import (
add_site_to_site_group,
create_site,
create_site_group,
create_user,
)


Expand Down Expand Up @@ -45,29 +42,28 @@ def seed_db():
Base.metadata.create_all(engine)

with db_conn.get_session() as session:
dummy_user = "[email protected]"

print("Seeding database")
site_group = create_site_group(session, site_group_name="dummy_site_group")

_ = create_user(
session, email=dummy_user, site_group_name=site_group.site_group_name
)

site, _ = create_site(
session,
client_site_id=1234,
client_site_name="dummy_site",
client_site_id=1,
client_site_name="dummy_site_1",
latitude=0.0,
longitude=0.0,
capacity_kw=10.0,
asset_type="pv",
country="india",
)

add_site_to_site_group(
site, _ = create_site(
session,
site_uuid=site.site_uuid,
site_group_name=site_group.site_group_name,
client_site_id=2,
client_site_name="dummy_site_2",
latitude=0.0,
longitude=0.0,
capacity_kw=10.0,
asset_type="wind",
country="india",
)

print("Database successfully seeded")
37 changes: 24 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,30 @@ def db_data(engine):

with engine.connect() as connection:
with Session(bind=connection) as session:
n_sites = 3

# Sites
for i in range(n_sites):
site = SiteSQL(
client_site_id=i + 1,
latitude=51,
longitude=3,
capacity_kw=4,
ml_id=i,
country="india"
)
session.add(site)

# PV site
site = SiteSQL(
client_site_id=1,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=1,
asset_type="pv",
country="india"
)
session.add(site)

# Wind site
site = SiteSQL(
client_site_id=2,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=2,
asset_type="wind",
country="india"
)
session.add(site)

session.commit()

Expand Down
38 changes: 20 additions & 18 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,45 @@

import pytest
from pvsite_datamodel.read import get_all_sites
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL

from india_forecast_app.app import app, get_model, get_site_ids, run_model, save_forecast
from india_forecast_app.app import app, get_model, get_sites, run_model, save_forecast
from india_forecast_app.model import DummyModel

from ._utils import run_click_script


def test_get_site_ids(db_session):
def test_get_sites(db_session):
"""Test for correct site ids"""

site_ids = get_site_ids(db_session)
sites = get_sites(db_session)
sites = sorted(sites, key=lambda s: s.client_site_id)

assert len(site_ids) == 3
for site_id in site_ids:
assert isinstance(site_id, uuid.UUID)
assert len(sites) == 2
for site in sites:
assert isinstance(site.site_uuid, uuid.UUID)
assert sites[0].asset_type.name == "pv"
assert sites[1].asset_type.name == "wind"


def test_get_model():
@pytest.mark.parametrize("asset_type", ["pv", "wind"])
def test_get_model(asset_type):
"""Test for getting valid model"""

model = get_model()
model = get_model(asset_type)

assert hasattr(model, 'version')
assert isinstance(model.version, str)
assert hasattr(model, 'predict')


def test_run_model(db_session):
"""Test for running model"""

site = db_session.query(SiteSQL).first()
model = DummyModel()
@pytest.mark.parametrize("asset_type", ["pv", "wind"])
def test_run_model(db_session, asset_type):
"""Test for running PV and wind models"""

forecast = run_model(
model=model,
site_id=site.site_uuid,
model=DummyModel(asset_type),
site_id=str(uuid.uuid4()),
timestamp=dt.datetime.now(tz=dt.UTC)
)

Expand Down Expand Up @@ -88,8 +90,8 @@ def test_app(write_to_db, db_session):
assert result.exit_code == 0

if write_to_db:
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 192)
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 192)
else:
assert db_session.query(ForecastSQL).count() == init_n_forecasts
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values
Loading