Skip to content

Commit

Permalink
Issue/add ml model (#150)
Browse files Browse the repository at this point in the history
* #147 add ml model

* add migration for datamodel

* add model filter + tests

* PR comment, rename function
  • Loading branch information
peterdudfield authored Aug 12, 2024
1 parent b46fef5 commit 127e7e4
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 12 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Classes specifying table schemas:
- GenerationSQL
- ForecastSQL
- ForecastValueSQL
- MLModelSQL
- UserSQL
- SiteSQL
- SiteGroupSQL
Expand Down Expand Up @@ -177,13 +178,20 @@ classDiagram
+ url : String
+ user_uuid : UUID ≪ FK ≫
}
class MLModelSQL{
+ model_uuid : UUID ≪ PK ≫
+ name : String
+ version : String
}
UserSQL "1" -- "N" SiteGroupSQL : belongs_to
SiteGroupSQL "N" -- "N" SiteSQL : contains
SiteGroupSQL "1" -- "N" SiteGroupSiteSQL : contains
SiteSQL "1" -- "N" GenerationSQL : generates
SiteSQL "1" -- "N" ForecastSQL : forecasts
ForecastSQL "1" -- "N" ForecastValueSQL : contains
MLModelSQL "1" -- "N" ForecastValueSQL : forecasts
SiteSQL "1" -- "N" InverterSQL : contains
UserSQL "1" -- "N" APIRequestSQL : performs_request
class Legend{
Expand Down
47 changes: 47 additions & 0 deletions alembic/versions/34891d466985_add_ml_models_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Add ml models table
Revision ID: 34891d466985
Revises: fb27362e3b6b
Create Date: 2024-08-07 12:26:23.631105
"""
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own id; `down_revision` is the migration it is
# applied on top of (the "Revises" entry in the module docstring).
revision = "34891d466985"
down_revision = "fb27362e3b6b"
# No branch label and no dependency on other revisions is declared.
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``ml_model`` table and link ``forecast_values`` rows to it.

    Two schema changes are made:

    * a new ``ml_model`` table with a server-generated ``model_uuid`` primary key
      plus ``name``, ``version`` and ``created_utc`` columns;
    * a new nullable ``ml_model_uuid`` column on ``forecast_values``.

    The new column carries a foreign key to ``ml_model.model_uuid`` so that the
    migrated schema matches ``ForecastValueSQL.ml_model_uuid``, which declares
    ``sa.ForeignKey("ml_model.model_uuid")``. Alembic emits a single-column
    ForeignKey passed to ``add_column`` as a separate ALTER statement.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "ml_model",
        sa.Column(
            "model_uuid", sa.UUID(), server_default=sa.text("gen_random_uuid()"), nullable=False
        ),
        sa.Column("name", sa.String(), nullable=True),
        sa.Column("version", sa.String(), nullable=True),
        sa.Column("created_utc", sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint("model_uuid"),
    )
    # The column is nullable, so forecast values can exist without a model link.
    op.add_column(
        "forecast_values",
        sa.Column(
            "ml_model_uuid",
            sa.UUID(),
            sa.ForeignKey("ml_model.model_uuid"),
            nullable=True,
            comment="The ML Model this forcast value belongs to",
        ),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Undo the upgrade: remove the model link column, then the ``ml_model`` table."""
    values_table, link_column = "forecast_values", "ml_model_uuid"

    # The column that points at ml_model goes first, the table itself second.
    op.drop_column(values_table, link_column)
    op.drop_table("ml_model")
1 change: 1 addition & 0 deletions pvsite_datamodel/read/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .generation import get_pv_generation_by_sites, get_pv_generation_by_user_uuids
from .latest_forecast_values import get_latest_forecast_values_by_site
from .model import get_or_create_model
from .site import (
get_all_sites,
get_site_by_client_site_id,
Expand Down
9 changes: 8 additions & 1 deletion pvsite_datamodel/read/latest_forecast_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.orm import Session, contains_eager

from pvsite_datamodel.pydantic_models import ForecastValueSum
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, MLModelSQL, SiteSQL


def get_latest_forecast_values_by_site(
Expand All @@ -21,6 +21,7 @@ def get_latest_forecast_values_by_site(
forecast_horizon_minutes: Optional[int] = None,
day_ahead_hours: Optional[int] = None,
day_ahead_timezone_delta_hours: Optional[float] = 0,
model_name: Optional[str] = None,
) -> Union[dict[uuid.UUID, list[ForecastValueSQL]], List[ForecastValueSum]]:
"""Get the forecast values by input sites, get the latest value.
Expand Down Expand Up @@ -59,6 +60,7 @@ def get_latest_forecast_values_by_site(
As datetimes are stored in UTC, we need to adjust the start_utc when looking at day
ahead forecast. For example a forecast made at 04:00 UTC for 20:00 UTC for India,
is actually a day ahead forecast, as India is 5.5 hours ahead of UTC
:param model_name: optional, filter on forecast values with this model name
"""

if sum_by not in ["total", "dno", "gsp", None]:
Expand Down Expand Up @@ -114,6 +116,11 @@ def get_latest_forecast_values_by_site(
)
)

if model_name is not None:
# join with MLModelSQL to filter on model_name
query = query.join(MLModelSQL)
query = query.filter(MLModelSQL.name == model_name)

# speed up query, so all information is gather in one query, rather than lots of little ones
query = query.options(contains_eager(ForecastValueSQL.forecast)).populate_existing()

Expand Down
52 changes: 52 additions & 0 deletions pvsite_datamodel/read/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
""" Read functions for getting ML models. """
import logging
from typing import Optional

from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import MLModelSQL

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_or_create_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL:
    """
    Get the model object for a name and (optionally) a version.

    A new model is made if no matching one exists. The new model is added to the
    session and committed immediately.

    If more than one row matches, the first row when ordered by version descending
    is returned. The ordering is done by the database on the version *string*, so
    it is not a semantic version comparison.

    :param session: database session
    :param name: name of the model
    :param version: version of the model. If None, models are matched on name only,
        and a newly created model is stored without a version
    :return: Model object
    """

    # filter on model name, and on version if one was given
    query = session.query(MLModelSQL).filter(MLModelSQL.name == name)
    if version is not None:
        query = query.filter(MLModelSQL.version == version)

    # only the top row is used, so ask the database for just that row
    model = query.order_by(MLModelSQL.version.desc()).first()

    if model is None:
        logger.debug(
            f"Model for name {name} and version {version} does not exist so going to add it"
        )

        model = MLModelSQL(name=name, version=version)
        session.add(model)
        session.commit()

    return model
25 changes: 24 additions & 1 deletion pvsite_datamodel/sqlmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import enum
import uuid
from datetime import datetime
from typing import List
from typing import List, Optional

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
Expand All @@ -22,6 +22,20 @@ class CreatedMixin:
created_utc = sa.Column(sa.DateTime, default=lambda: datetime.utcnow())


class MLModelSQL(Base, CreatedMixin):
    """Class representing the ml_model table.

    Each row is one ML model, described by a name and a version.
    Forecast values can optionally point at the model that produced them.

    *Approximate size: *
    Unknown - NOTE(review): presumably one row per model name/version pair; confirm.
    """

    __tablename__ = "ml_model"

    # as_uuid=True is spelled out to match ForecastValueSQL.ml_model_uuid, the
    # column that references this key. The value is generated by the database.
    model_uuid = sa.Column(
        UUID(as_uuid=True), primary_key=True, server_default=sa.func.gen_random_uuid()
    )
    name = sa.Column(sa.String)
    version = sa.Column(sa.String)

    # One model has many forecast values (other side: ForecastValueSQL.ml_model).
    forecast_values: Mapped[List["ForecastValueSQL"]] = relationship(
        "ForecastValueSQL", back_populates="ml_model"
    )


class UserSQL(Base, CreatedMixin):
"""Class representing the users table.
Expand Down Expand Up @@ -306,8 +320,17 @@ class ForecastValueSQL(Base, CreatedMixin):
nullable=False,
comment="The forecast sequence this forcast value belongs to",
)
ml_model_uuid = sa.Column(
UUID(as_uuid=True),
sa.ForeignKey("ml_model.model_uuid"),
nullable=True,
comment="The ML Model this forcast value belongs to",
)

forecast: Mapped["ForecastSQL"] = relationship("ForecastSQL", back_populates="forecast_values")
ml_model: Mapped[Optional[MLModelSQL]] = relationship(
"MLModelSQL", back_populates="forecast_values"
)

__table_args__ = (
# Here we assume that we always filter on `horizon_minutes` *for given forecasts*.
Expand Down
11 changes: 11 additions & 0 deletions pvsite_datamodel/write/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

import logging
from typing import Optional

import pandas as pd
from sqlalchemy.orm import Session

from pvsite_datamodel.read.model import get_or_create_model
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL

_log = logging.getLogger(__name__)
Expand All @@ -16,6 +18,8 @@ def insert_forecast_values(
session: Session,
forecast_meta: dict,
forecast_values_df: pd.DataFrame,
ml_model_name: Optional[str] = None,
ml_model_version: Optional[str] = None,
):
"""Insert a dataframe of forecast values and forecast meta info into the database.
Expand All @@ -30,12 +34,19 @@ def insert_forecast_values(
# Flush to get the Forecast's primary key.
session.flush()

# Look up (or create) the ML model only when both its name and version are given;
# otherwise the forecast values are stored without a model link.
if (ml_model_name is not None) and (ml_model_version is not None):
    ml_model = get_or_create_model(session, ml_model_name, ml_model_version)
    # MLModelSQL's primary key attribute is `model_uuid`; it has no `uuid` attribute.
    ml_model_uuid = ml_model.model_uuid
else:
    ml_model_uuid = None

rows = forecast_values_df.to_dict("records")
session.bulk_save_objects(
[
ForecastValueSQL(
**row,
forecast_uuid=forecast.forecast_uuid,
ml_model_uuid=ml_model_uuid,
)
for row in rows
]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@pytest.fixture(scope="session")
def engine():
"""Database engine fixture."""
with PostgresContainer("postgres:14.5") as postgres:
with PostgresContainer("postgres:15.5") as postgres:
# TODO need to setup postgres database with docker
url = postgres.get_connection_url()
engine = create_engine(url)
Expand Down
4 changes: 3 additions & 1 deletion tests/read/test_get_api_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def test_get_api_requests_for_one_user_end_datetime(db_session):
db_session.add(APIRequestSQL(user_uuid=user.user_uuid, url="test"))

requests_sql = get_api_requests_for_one_user(
session=db_session, email=user.email, end_datetime=dt.datetime.now() - dt.timedelta(hours=1)
session=db_session,
email=user.email,
end_datetime=dt.datetime.now(tz=dt.timezone.utc) - dt.timedelta(hours=1),
)
assert len(requests_sql) == 0
Loading

0 comments on commit 127e7e4

Please sign in to comment.