diff --git a/alembic/versions/2a6e6975cd72_remove_uniqueness_from_ml_id_for_site_.py b/alembic/versions/2a6e6975cd72_remove_uniqueness_from_ml_id_for_site_.py new file mode 100644 index 0000000..d1e171d --- /dev/null +++ b/alembic/versions/2a6e6975cd72_remove_uniqueness_from_ml_id_for_site_.py @@ -0,0 +1,39 @@ +"""Remove uniqueness from ml id for site table + +Revision ID: 2a6e6975cd72 +Revises: 34891d466985 +Create Date: 2024-09-20 15:23:59.973409 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '2a6e6975cd72' +down_revision = '34891d466985' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('sites', 'ml_id', + existing_type=sa.INTEGER(), + comment='Auto-incrementing integer ID of the site for use in ML training', + existing_nullable=False, + autoincrement=True) + op.drop_constraint('sites_ml_id_key', 'sites', type_='unique') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_unique_constraint('sites_ml_id_key', 'sites', ['ml_id']) + op.alter_column('sites', 'ml_id', + existing_type=sa.INTEGER(), + comment=None, + existing_comment='Auto-incrementing integer ID of the site for use in ML training', + existing_nullable=False, + autoincrement=True) + # ### end Alembic commands ### diff --git a/poetry.lock b/poetry.lock index 8c894e4..00bec0e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alembic" @@ -1470,4 +1470,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "f7e4ba9700ecfa7286bc318f3a1e9946a3ae683947531537ec1b90ef8df90c39" +content-hash = "1bdab9016adf485c3a8a1da6a25cb47847f08a127ee6f4ca56e4753079a8871b" diff --git a/pvsite_datamodel/sqlmodels.py b/pvsite_datamodel/sqlmodels.py index 4577f11..d411d5b 100644 --- a/pvsite_datamodel/sqlmodels.py +++ b/pvsite_datamodel/sqlmodels.py @@ -165,7 +165,6 @@ class SiteSQL(Base, CreatedMixin): sa.Integer, autoincrement=True, nullable=False, - unique=True, comment="Auto-incrementing integer ID of the site for use in ML training", ) diff --git a/pvsite_datamodel/write/user_and_site.py b/pvsite_datamodel/write/user_and_site.py index 0449613..dfbeb3e 100644 --- a/pvsite_datamodel/write/user_and_site.py +++ b/pvsite_datamodel/write/user_and_site.py @@ -74,6 +74,7 @@ def create_site( tilt: Optional[float] = None, inverter_capacity_kw: Optional[float] = None, module_capacity_kw: Optional[float] = None, + ml_id: Optional[int] = None, ) -> [SiteSQL, str]: """ Create a site and adds it to the database. @@ -93,6 +94,7 @@ def create_site( :param tilt: tilt of site, default is 35 :param inverter_capacity_kw: inverter capacity of site in kw :param module_capacity_kw: module capacity of site in kw + :param ml_id: internal ML modelling id """ max_ml_id = session.query(func.max(SiteSQL.ml_id)).scalar() @@ -133,7 +135,7 @@ def create_site( dno = json.dumps(dno) site = SiteSQL( - ml_id=max_ml_id + 1, + ml_id=ml_id if ml_id else max_ml_id + 1, client_site_id=client_site_id, client_site_name=client_site_name, latitude=latitude, diff --git a/tests/test_write.py b/tests/test_write.py index fa7d1d5..d66e789 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -203,6 +203,34 @@ def test_create_new_site_twice(db_session): assert site_2.ml_id == 2 +# test for create_site to check ml_id duplicates are allowed for separate clients +def test_ml_id_duplicate_for_unique_clients(db_session): + """Test create sites with duplicate ml_id for different clients""" + + site_1, _ = create_site( + session=db_session, + client_site_id=6932, + client_site_name="test_site_name_1", + latitude=1.0, + longitude=1.0, + capacity_kw=1.0, + ml_id=1, + ) + + site_2, _ = create_site( + session=db_session, + client_site_id=6933, + client_site_name="test_site_name_2", + latitude=1.0, + longitude=1.0, + capacity_kw=1.0, + ml_id=1, + ) + + assert site_1.ml_id == 1 + assert site_2.ml_id == 1 + + def test_create_user(db_session): "Test to create a new user."