Skip to content

Commit

Permalink
add get models + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed May 22, 2024
1 parent 4f51d06 commit 2451fa5
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
38 changes: 38 additions & 0 deletions nowcasting_datamodel/read/read_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
""" Read functions for models"""

from typing import Optional
from datetime import datetime
from nowcasting_datamodel.models.models import MLModelSQL
from nowcasting_datamodel.models.forecast import ForecastSQL


def get_models(
session, with_forecasts: Optional[bool] = False, forecast_created_utc: Optional[datetime] = None
) -> MLModelSQL:
"""
Get all models from the database, distinct on name
Can also filter if the model has forecasts, and if they are made after a certain time
:param session: sql session
:param with_forecasts: only get models with forecasts
:param forecast_created_utc: only look at forecast created after this time
:return:
"""

query = session.query(MLModelSQL)
query = query.distinct(MLModelSQL.name)

if with_forecasts:
query = query.join(ForecastSQL)

if forecast_created_utc:
query = query.filter(ForecastSQL.created_utc > forecast_created_utc)

# order by
query = query.order_by(MLModelSQL.name)

# get all results
models = query.all()

return models
46 changes: 46 additions & 0 deletions tests/read/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from nowcasting_datamodel.fake import make_fake_forecast

from datetime import datetime, timedelta, timezone

from nowcasting_datamodel.read.read_models import get_models
from nowcasting_datamodel.read.read import get_model


def test_get_models(db_session):

_ = make_fake_forecast(session=db_session, gsp_id=1)

models = get_models(session=db_session)
assert len(models) == 1


def test_get_models_no_models(db_session):
models = get_models(session=db_session)
assert len(models) == 0


def test_get_models_no_forecasts(db_session):
get_model(session=db_session, name="test")

models = get_models(session=db_session, with_forecasts=True)
assert len(models) == 0


def test_get_models_multiple_models(db_session):

_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_1")
_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_2")

models = get_models(session=db_session)
assert len(models) == 2


def testget_models_after_created(db_session):
_ = make_fake_forecast(session=db_session, gsp_id=1, model_name="test_1")

models = get_models(
session=db_session,
with_forecasts=True,
forecast_created_utc=datetime.now(tz=timezone.utc) + timedelta(days=1),
)
assert len(models) == 0

0 comments on commit 2451fa5

Please sign in to comment.