Skip to content

Commit

Permalink
add get models + tests (#279)
Browse files Browse the repository at this point in the history
* add get models + tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
peterdudfield and pre-commit-ci[bot] authored May 22, 2024
1 parent 4f51d06 commit b30a662
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 56 deletions.
3 changes: 2 additions & 1 deletion nowcasting_datamodel/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
)
from nowcasting_datamodel.models.forecast import ForecastSQL, ForecastValueSQL
from nowcasting_datamodel.models.gsp import GSPYieldSQL
from nowcasting_datamodel.read.read import get_location, get_model
from nowcasting_datamodel.read.read import get_location
from nowcasting_datamodel.read.read_metric import get_datetime_interval
from nowcasting_datamodel.read.read_models import get_model
from nowcasting_datamodel.save.update import change_forecast_value_to_latest

# 2 days in the past + 8 hours forward at 30 mins interval
Expand Down
2 changes: 1 addition & 1 deletion nowcasting_datamodel/models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from nowcasting_datamodel.read.read import (
get_latest_input_data_last_updated,
get_location,
get_model,
)
from nowcasting_datamodel.read.read_models import get_model

logger = logging.getLogger()

Expand Down
41 changes: 0 additions & 41 deletions nowcasting_datamodel/read/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,47 +759,6 @@ def get_all_locations(session: Session, gsp_ids: List[int] = None) -> List[Locat
return locations


def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL:
"""
Get model object from name and version
:param session: database session
:param name: name of the model
:param version: version of the model
return: Model object
"""

# start main query
query = session.query(MLModelSQL)

# filter on gsp_id
query = query.filter(MLModelSQL.name == name)
if version is not None:
query = query.filter(MLModelSQL.version == version)

# gets the latest version
query = query.order_by(MLModelSQL.version.desc())

# get all results
models = query.all()

if len(models) == 0:
logger.debug(
f"Model for name {name} and version {version} does not exist so going to add it"
)

model = MLModelSQL(name=name, version=version)
session.add(model)
session.commit()

else:
model = models[0]

return model


def get_pv_system(
session: Session, pv_system_id: int, provider: Optional[str] = "pvoutput.org"
) -> PVSystemSQL:
Expand Down
83 changes: 83 additions & 0 deletions nowcasting_datamodel/read/read_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
""" Read functions for models"""

from datetime import datetime
from typing import Optional

from sqlalchemy.orm import Session

from nowcasting_datamodel.models import MLModelSQL
from nowcasting_datamodel.models.forecast import ForecastSQL
from nowcasting_datamodel.read.read import logger


def get_models(
session, with_forecasts: Optional[bool] = False, forecast_created_utc: Optional[datetime] = None
) -> MLModelSQL:
"""
Get all models from the database, distinct on name
Can also filter if the model has forecasts, and if they are made after a certain time
:param session: sql session
:param with_forecasts: only get models with forecasts
:param forecast_created_utc: only look at forecast created after this time
:return:
"""

query = session.query(MLModelSQL)
query = query.distinct(MLModelSQL.name)

if with_forecasts:
query = query.join(ForecastSQL)

if forecast_created_utc:
query = query.filter(ForecastSQL.created_utc > forecast_created_utc)

# order by
query = query.order_by(MLModelSQL.name)

# get all results
models = query.all()

return models


def get_model(session: Session, name: str, version: Optional[str] = None) -> MLModelSQL:
"""
Get model object from name and version
:param session: database session
:param name: name of the model
:param version: version of the model
return: Model object
"""

# start main query
query = session.query(MLModelSQL)

# filter on gsp_id
query = query.filter(MLModelSQL.name == name)
if version is not None:
query = query.filter(MLModelSQL.version == version)

# gets the latest version
query = query.order_by(MLModelSQL.version.desc())

# get all results
models = query.all()

if len(models) == 0:
logger.debug(
f"Model for name {name} and version {version} does not exist so going to add it"
)

model = MLModelSQL(name=name, version=version)
session.add(model)
session.commit()

else:
model = models[0]

return model
12 changes: 0 additions & 12 deletions tests/read/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from nowcasting_datamodel.models import (
InputDataLastUpdatedSQL,
LocationSQL,
MLModel,
PVSystem,
Status,
national_gb_label,
Expand All @@ -36,7 +35,6 @@
get_latest_national_forecast,
get_latest_status,
get_location,
get_model,
get_pv_system,
update_latest_input_data_last_updated,
)
Expand Down Expand Up @@ -70,16 +68,6 @@ def test_get_national_location(db_session):
_ = get_location(session=db_session, gsp_id=0, label="test_label")


def test_get_model(db_session):
model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9")
model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9")

assert model_read_1.name == model_read_2.name
assert model_read_1.version == model_read_2.version

_ = MLModel.from_orm(model_read_2)


def test_get_forecast(db_session, forecasts):
forecast_read = get_latest_forecast(session=db_session)

Expand Down
56 changes: 56 additions & 0 deletions tests/read/test_read_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from nowcasting_datamodel.fake import make_fake_forecast

from datetime import datetime, timedelta, timezone

from nowcasting_datamodel.models import MLModel
from nowcasting_datamodel.read.read_models import get_models, get_model


def test_get_models(db_session):

_ = make_fake_forecast(session=db_session, gsp_id=1)

models = get_models(session=db_session)
assert len(models) == 1


def test_get_models_no_models(db_session):
models = get_models(session=db_session)
assert len(models) == 0


def test_get_models_no_forecasts(db_session):
get_model(session=db_session, name="test")

models = get_models(session=db_session, with_forecasts=True)
assert len(models) == 0


def test_get_models_multiple_models(db_session):

_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_1")
_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_2")

models = get_models(session=db_session)
assert len(models) == 2


def testget_models_after_created(db_session):
_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_1")

models = get_models(
session=db_session,
with_forecasts=True,
forecast_created_utc=datetime.now(tz=timezone.utc) + timedelta(days=1),
)
assert len(models) == 0


def test_get_model(db_session):
model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9")
model_read_2 = get_model(session=db_session, name="test_name", version="9.9.9")

assert model_read_1.name == model_read_2.name
assert model_read_1.version == model_read_2.version

_ = MLModel.from_orm(model_read_2)
2 changes: 1 addition & 1 deletion tests/save/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
make_fake_forecast_value,
make_fake_forecasts,
)
from nowcasting_datamodel.read.read import get_model
from nowcasting_datamodel.read.read_models import get_model
from nowcasting_datamodel.models import ForecastValueSevenDaysSQL
from nowcasting_datamodel.models.forecast import (
ForecastSQL,
Expand Down

0 comments on commit b30a662

Please sign in to comment.