diff --git a/README.md b/README.md index 0cf72ce..f6f236a 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ classDiagram + module_capacity_kw : Float + ml_id : Integer ≪ U ≫ + client_uuid : UUID ≪ FK ≫ + + ml_model_uuid : UUID ≪ FK ≫ } class ClientSQL{ @@ -208,6 +209,7 @@ classDiagram SiteGroupSQL "1" -- "N" SiteGroupSiteSQL : contains SiteSQL "1" -- "N" GenerationSQL : generates SiteSQL "1" -- "N" ForecastSQL : forecasts + SiteSQL "N" -- "0" MLModelSQL : ml_model ForecastSQL "1" -- "N" ForecastValueSQL : contains MLModelSQL "1" -- "N" ForecastValueSQL : forecasts SiteSQL "1" -- "N" InverterSQL : contains diff --git a/pvsite_datamodel/sqlmodels.py b/pvsite_datamodel/sqlmodels.py index c41bf3c..9536527 100644 --- a/pvsite_datamodel/sqlmodels.py +++ b/pvsite_datamodel/sqlmodels.py @@ -35,6 +35,8 @@ class MLModelSQL(Base, CreatedMixin): "ForecastValueSQL", back_populates="ml_model" ) + sites: Mapped[List["SiteSQL"]] = relationship("SiteSQL", back_populates="ml_model") + class UserSQL(Base, CreatedMixin): """Class representing the users table. @@ -181,6 +183,13 @@ class SiteSQL(Base, CreatedMixin): comment="The UUID of the client this site belongs to", ) + ml_model_uuid = sa.Column( + UUID(as_uuid=True), + sa.ForeignKey("ml_model.model_uuid"), + nullable=True, + comment="The ML Model which should be used for this site", + ) + forecasts: Mapped[List["ForecastSQL"]] = relationship("ForecastSQL", back_populates="site") generation: Mapped[List["GenerationSQL"]] = relationship("GenerationSQL") inverters: Mapped[List["InverterSQL"]] = relationship( @@ -190,6 +199,7 @@ class SiteSQL(Base, CreatedMixin): "SiteGroupSQL", secondary="site_group_sites", back_populates="sites" ) client: Mapped[List["ClientSQL"]] = relationship("ClientSQL", back_populates="sites") + ml_model: Mapped[Optional[MLModelSQL]] = relationship("MLModelSQL", back_populates="sites") class ClientSQL(Base, CreatedMixin): diff --git a/pvsite_datamodel/write/user_and_site.py b/pvsite_datamodel/write/user_and_site.py index 0ff7e0e..27c6116 100644 --- a/pvsite_datamodel/write/user_and_site.py +++ b/pvsite_datamodel/write/user_and_site.py @@ -10,7 +10,7 @@ from sqlalchemy.sql.functions import func from pvsite_datamodel.pydantic_models import PVSiteEditMetadata -from pvsite_datamodel.read import get_user_by_email +from pvsite_datamodel.read import get_or_create_model, get_site_by_uuid, get_user_by_email from pvsite_datamodel.sqlmodels import ( ForecastSQL, ForecastValueSQL, @@ -376,3 +376,16 @@ def delete_site_group(session: Session, site_group_name: str) -> str: session.commit() return message + + +def assign_model_name_to_site(session: Session, site_uuid, model_name): + """ + Assign a model name to a site. + """ + + site = get_site_by_uuid(session=session, site_uuid=site_uuid) + + model = get_or_create_model(session=session, name=model_name) + + site.ml_model_uuid = model.model_uuid + session.commit() diff --git a/tests/conftest.py b/tests/conftest.py index 51893c0..0d5116c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,10 @@ def engine(): os.environ["DB_URL"] = url command.upgrade(alembic_cfg, "head") + # If you haven't run alembic migration yet, you might want to just run the tests with this + # from pvsite_datamodel.sqlmodels import Base + # Base.metadata.create_all(engine) + yield engine diff --git a/tests/test_write.py b/tests/test_write.py index eba1d61..248a57b 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -18,6 +18,7 @@ from pvsite_datamodel.write.generation import insert_generation_values from pvsite_datamodel.write.user_and_site import ( add_site_to_site_group, + assign_model_name_to_site, change_user_site_group, create_site, create_site_group, @@ -342,3 +343,16 @@ def test_assign_site_to_client(db_session): f"Site with site uuid {site.site_uuid} successfully assigned " f"to the client {client.client_name}" ) + + +def test_assign_model_name_to_site(db_session): + """Test to assign a model name to a site""" + site = make_fake_site(db_session=db_session) + + assign_model_name_to_site(db_session, site.site_uuid, "test_model") + + assert site.ml_model.name == "test_model" + + assign_model_name_to_site(db_session, site.site_uuid, "test_model_2") + + assert site.ml_model.name == "test_model_2"