Skip to content

Commit

Permalink
add optional ml model to sites table
Browse files Browse the repository at this point in the history
add migration
add function to assign model to site + test
  • Loading branch information
peterdudfield committed Nov 25, 2024
1 parent 2dd2d3a commit 0b77cdb
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ classDiagram
+ module_capacity_kw : Float
+ ml_id : Integer ≪ U ≫
+ client_uuid : UUID ≪ FK ≫
+ ml_model_uuid : UUID ≪ FK ≫
}
class ClientSQL{
Expand Down Expand Up @@ -208,6 +209,7 @@ classDiagram
SiteGroupSQL "1" -- "N" SiteGroupSiteSQL : contains
SiteSQL "1" -- "N" GenerationSQL : generates
SiteSQL "1" -- "N" ForecastSQL : forecasts
SiteSQL "N" -- "0" MLModelSQL : ml_model
ForecastSQL "1" -- "N" ForecastValueSQL : contains
MLModelSQL "1" -- "N" ForecastValueSQL : forecasts
SiteSQL "1" -- "N" InverterSQL : contains
Expand Down
10 changes: 10 additions & 0 deletions pvsite_datamodel/sqlmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class MLModelSQL(Base, CreatedMixin):
"ForecastValueSQL", back_populates="ml_model"
)

sites: Mapped[List["SiteSQL"]] = relationship("SiteSQL", back_populates="ml_model")


class UserSQL(Base, CreatedMixin):
"""Class representing the users table.
Expand Down Expand Up @@ -181,6 +183,13 @@ class SiteSQL(Base, CreatedMixin):
comment="The UUID of the client this site belongs to",
)

ml_model_uuid = sa.Column(
UUID(as_uuid=True),
sa.ForeignKey("ml_model.model_uuid"),
nullable=True,
comment="The ML Model which should be used for this site",
)

forecasts: Mapped[List["ForecastSQL"]] = relationship("ForecastSQL", back_populates="site")
generation: Mapped[List["GenerationSQL"]] = relationship("GenerationSQL")
inverters: Mapped[List["InverterSQL"]] = relationship(
Expand All @@ -190,6 +199,7 @@ class SiteSQL(Base, CreatedMixin):
"SiteGroupSQL", secondary="site_group_sites", back_populates="sites"
)
client: Mapped[List["ClientSQL"]] = relationship("ClientSQL", back_populates="sites")
ml_model: Mapped[Optional[MLModelSQL]] = relationship("MLModelSQL", back_populates="sites")


class ClientSQL(Base, CreatedMixin):
Expand Down
15 changes: 14 additions & 1 deletion pvsite_datamodel/write/user_and_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlalchemy.sql.functions import func

from pvsite_datamodel.pydantic_models import PVSiteEditMetadata
from pvsite_datamodel.read import get_user_by_email
from pvsite_datamodel.read import get_or_create_model, get_site_by_uuid, get_user_by_email
from pvsite_datamodel.sqlmodels import (
ForecastSQL,
ForecastValueSQL,
Expand Down Expand Up @@ -376,3 +376,16 @@ def delete_site_group(session: Session, site_group_name: str) -> str:
session.commit()

return message


def assign_model_name_to_site(session: Session, site_uuid, model_name):
"""
Assign a model name to a site.
"""

site = get_site_by_uuid(session=session, site_uuid=site_uuid)

model = get_or_create_model(session=session, name=model_name)

site.ml_model_uuid = model.model_uuid
session.commit()
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def engine():
os.environ["DB_URL"] = url
command.upgrade(alembic_cfg, "head")

# If you haven't run alembic migration yet, you might want to just run the tests with this
# from pvsite_datamodel.sqlmodels import Base
# Base.metadata.create_all(engine)

yield engine


Expand Down
14 changes: 14 additions & 0 deletions tests/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pvsite_datamodel.write.generation import insert_generation_values
from pvsite_datamodel.write.user_and_site import (
add_site_to_site_group,
assign_model_name_to_site,
change_user_site_group,
create_site,
create_site_group,
Expand Down Expand Up @@ -342,3 +343,16 @@ def test_assign_site_to_client(db_session):
f"Site with site uuid {site.site_uuid} successfully assigned "
f"to the client {client.client_name}"
)


def test_assign_model_name_to_site(db_session):
"""Test to assign a model name to a site"""
site = make_fake_site(db_session=db_session)

assign_model_name_to_site(db_session, site.site_uuid, "test_model")

assert site.ml_model.name == "test_model"

assign_model_name_to_site(db_session, site.site_uuid, "test_model_2")

assert site.ml_model.name == "test_model_2"

0 comments on commit 0b77cdb

Please sign in to comment.