Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed May 22, 2024
1 parent 7e87c25 commit 6d0a919
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 58 deletions.
3 changes: 2 additions & 1 deletion nowcasting_datamodel/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
)
from nowcasting_datamodel.models.forecast import ForecastSQL, ForecastValueSQL
from nowcasting_datamodel.models.gsp import GSPYieldSQL
from nowcasting_datamodel.read.read import get_location, get_model
from nowcasting_datamodel.read.read import get_location
from nowcasting_datamodel.read.read_models import get_model
from nowcasting_datamodel.read.read_metric import get_datetime_interval
from nowcasting_datamodel.save.update import change_forecast_value_to_latest

Expand Down
2 changes: 1 addition & 1 deletion nowcasting_datamodel/models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from nowcasting_datamodel.read.read import (
get_latest_input_data_last_updated,
get_location,
get_model,
)
from nowcasting_datamodel.read.read_models import get_model

logger = logging.getLogger()

Expand Down
41 changes: 0 additions & 41 deletions nowcasting_datamodel/read/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,47 +759,6 @@ def get_all_locations(session: Session, gsp_ids: List[int] = None) -> List[Locat
return locations


def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL:
"""
Get model object from name and version
:param session: database session
:param name: name of the model
:param version: version of the model
return: Model object
"""

# start main query
query = session.query(MLModelSQL)

# filter on gsp_id
query = query.filter(MLModelSQL.name == name)
if version is not None:
query = query.filter(MLModelSQL.version == version)

# gets the latest version
query = query.order_by(MLModelSQL.version.desc())

# get all results
models = query.all()

if len(models) == 0:
logger.debug(
f"Model for name {name} and version {version} does not exist so going to add it"
)

model = MLModelSQL(name=name, version=version)
session.add(model)
session.commit()

else:
model = models[0]

return model


def get_pv_system(
session: Session, pv_system_id: int, provider: Optional[str] = "pvoutput.org"
) -> PVSystemSQL:
Expand Down
45 changes: 45 additions & 0 deletions nowcasting_datamodel/read/read_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from datetime import datetime
from typing import Optional

from sqlalchemy.orm import Session

from nowcasting_datamodel.models import MLModelSQL
from nowcasting_datamodel.models.forecast import ForecastSQL
from nowcasting_datamodel.models.models import MLModelSQL
from nowcasting_datamodel.read.read import logger


def get_models(
Expand Down Expand Up @@ -37,3 +41,44 @@ def get_models(
models = query.all()

return models


def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL:
"""
Get model object from name and version
:param session: database session
:param name: name of the model
:param version: version of the model
return: Model object
"""

# start main query
query = session.query(MLModelSQL)

# filter on gsp_id
query = query.filter(MLModelSQL.name == name)
if version is not None:
query = query.filter(MLModelSQL.version == version)

# gets the latest version
query = query.order_by(MLModelSQL.version.desc())

# get all results
models = query.all()

if len(models) == 0:
logger.debug(
f"Model for name {name} and version {version} does not exist so going to add it"
)

model = MLModelSQL(name=name, version=version)
session.add(model)
session.commit()

else:
model = models[0]

return model
12 changes: 0 additions & 12 deletions tests/read/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from nowcasting_datamodel.models import (
InputDataLastUpdatedSQL,
LocationSQL,
MLModel,
PVSystem,
Status,
national_gb_label,
Expand All @@ -36,7 +35,6 @@
get_latest_national_forecast,
get_latest_status,
get_location,
get_model,
get_pv_system,
update_latest_input_data_last_updated,
)
Expand Down Expand Up @@ -70,16 +68,6 @@ def test_get_national_location(db_session):
_ = get_location(session=db_session, gsp_id=0, label="test_label")


def test_get_model(db_session):
model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9")
model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9")

assert model_read_1.name == model_read_2.name
assert model_read_1.version == model_read_2.version

_ = MLModel.from_orm(model_read_2)


def test_get_forecast(db_session, forecasts):
forecast_read = get_latest_forecast(session=db_session)

Expand Down
14 changes: 12 additions & 2 deletions tests/read/test_read_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from datetime import datetime, timedelta, timezone

from nowcasting_datamodel.read.read_models import get_models
from nowcasting_datamodel.read.read import get_model
from nowcasting_datamodel.models import MLModel
from nowcasting_datamodel.read.read_models import get_models, get_model


def test_get_models(db_session):
Expand Down Expand Up @@ -44,3 +44,13 @@ def testget_models_after_created(db_session):
forecast_created_utc=datetime.now(tz=timezone.utc) + timedelta(days=1),
)
assert len(models) == 0


def test_get_model(db_session):
model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9")
model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9")

assert model_read_1.name == model_read_2.name
assert model_read_1.version == model_read_2.version

_ = MLModel.from_orm(model_read_2)
2 changes: 1 addition & 1 deletion tests/save/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
make_fake_forecast_value,
make_fake_forecasts,
)
from nowcasting_datamodel.read.read import get_model
from nowcasting_datamodel.read.read_models import get_model
from nowcasting_datamodel.models import ForecastValueSevenDaysSQL
from nowcasting_datamodel.models.forecast import (
ForecastSQL,
Expand Down

0 comments on commit 6d0a919

Please sign in to comment.