diff --git a/nowcasting_datamodel/fake.py b/nowcasting_datamodel/fake.py index f71148ba..f6e37f50 100644 --- a/nowcasting_datamodel/fake.py +++ b/nowcasting_datamodel/fake.py @@ -17,7 +17,8 @@ ) from nowcasting_datamodel.models.forecast import ForecastSQL, ForecastValueSQL from nowcasting_datamodel.models.gsp import GSPYieldSQL -from nowcasting_datamodel.read.read import get_location, get_model +from nowcasting_datamodel.read.read import get_location +from nowcasting_datamodel.read.read_models import get_model from nowcasting_datamodel.read.read_metric import get_datetime_interval from nowcasting_datamodel.save.update import change_forecast_value_to_latest diff --git a/nowcasting_datamodel/models/convert.py b/nowcasting_datamodel/models/convert.py index 6918e93b..9e5fd888 100644 --- a/nowcasting_datamodel/models/convert.py +++ b/nowcasting_datamodel/models/convert.py @@ -16,8 +16,8 @@ from nowcasting_datamodel.read.read import ( get_latest_input_data_last_updated, get_location, - get_model, ) +from nowcasting_datamodel.read.read_models import get_model logger = logging.getLogger() diff --git a/nowcasting_datamodel/read/read.py b/nowcasting_datamodel/read/read.py index 6a669b61..95870cdf 100644 --- a/nowcasting_datamodel/read/read.py +++ b/nowcasting_datamodel/read/read.py @@ -759,47 +759,6 @@ def get_all_locations(session: Session, gsp_ids: List[int] = None) -> List[Locat return locations -def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL: - """ - Get model object from name and version - - :param session: database session - :param name: name of the model - :param version: version of the model - - return: Model object - - """ - - # start main query - query = session.query(MLModelSQL) - - # filter on gsp_id - query = query.filter(MLModelSQL.name == name) - if version is not None: - query = query.filter(MLModelSQL.version == version) - - # gets the latest version - query = query.order_by(MLModelSQL.version.desc()) - - # get all results - models = query.all() - - if len(models) == 0: - logger.debug( - f"Model for name {name} and version {version} does not exist so going to add it" - ) - - model = MLModelSQL(name=name, version=version) - session.add(model) - session.commit() - - else: - model = models[0] - - return model - - def get_pv_system( session: Session, pv_system_id: int, provider: Optional[str] = "pvoutput.org" ) -> PVSystemSQL: diff --git a/nowcasting_datamodel/read/read_models.py b/nowcasting_datamodel/read/read_models.py index e34868ee..02f15dfb 100644 --- a/nowcasting_datamodel/read/read_models.py +++ b/nowcasting_datamodel/read/read_models.py @@ -3,8 +3,12 @@ from datetime import datetime from typing import Optional +from sqlalchemy.orm import Session + +from nowcasting_datamodel.models import MLModelSQL from nowcasting_datamodel.models.forecast import ForecastSQL from nowcasting_datamodel.models.models import MLModelSQL +from nowcasting_datamodel.read.read import logger def get_models( @@ -37,3 +41,44 @@ def get_models( models = query.all() return models + + +def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL: + """ + Get model object from name and version + + :param session: database session + :param name: name of the model + :param version: version of the model + + return: Model object + + """ + + # start main query + query = session.query(MLModelSQL) + + # filter on gsp_id + query = query.filter(MLModelSQL.name == name) + if version is not None: + query = query.filter(MLModelSQL.version == version) + + # gets the latest version + query = query.order_by(MLModelSQL.version.desc()) + + # get all results + models = query.all() + + if len(models) == 0: + logger.debug( + f"Model for name {name} and version {version} does not exist so going to add it" + ) + + model = MLModelSQL(name=name, version=version) + session.add(model) + session.commit() + + else: + model = models[0] + + return model diff --git a/tests/read/test_read.py b/tests/read/test_read.py index c65b8f81..e4edb97a 100644 --- a/tests/read/test_read.py +++ b/tests/read/test_read.py @@ -14,7 +14,6 @@ from nowcasting_datamodel.models import ( InputDataLastUpdatedSQL, LocationSQL, - MLModel, PVSystem, Status, national_gb_label, @@ -36,7 +35,6 @@ get_latest_national_forecast, get_latest_status, get_location, - get_model, get_pv_system, update_latest_input_data_last_updated, ) @@ -70,16 +68,6 @@ def test_get_national_location(db_session): _ = get_location(session=db_session, gsp_id=0, label="test_label") -def test_get_model(db_session): - model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9") - model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9") - - assert model_read_1.name == model_read_2.name - assert model_read_1.version == model_read_2.version - - _ = MLModel.from_orm(model_read_2) - - def test_get_forecast(db_session, forecasts): forecast_read = get_latest_forecast(session=db_session) diff --git a/tests/read/test_read_models.py b/tests/read/test_read_models.py index 45aa0a52..012e91fd 100644 --- a/tests/read/test_read_models.py +++ b/tests/read/test_read_models.py @@ -2,8 +2,8 @@ from datetime import datetime, timedelta, timezone -from nowcasting_datamodel.read.read_models import get_models -from nowcasting_datamodel.read.read import get_model +from nowcasting_datamodel.models import MLModel +from nowcasting_datamodel.read.read_models import get_models, get_model def test_get_models(db_session): @@ -44,3 +44,13 @@ def testget_models_after_created(db_session): forecast_created_utc=datetime.now(tz=timezone.utc) + timedelta(days=1), ) assert len(models) == 0 + + +def test_get_model(db_session): + model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9") + model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9") + + assert model_read_1.name == model_read_2.name + assert model_read_1.version == model_read_2.version + + _ = MLModel.from_orm(model_read_2) diff --git a/tests/save/test_update.py b/tests/save/test_update.py index 869cf157..7537f380 100644 --- a/tests/save/test_update.py +++ b/tests/save/test_update.py @@ -8,7 +8,7 @@ make_fake_forecast_value, make_fake_forecasts, ) -from nowcasting_datamodel.read.read import get_model +from nowcasting_datamodel.read.read_models import get_model from nowcasting_datamodel.models import ForecastValueSevenDaysSQL from nowcasting_datamodel.models.forecast import ( ForecastSQL,