diff --git a/pvsite_datamodel/write/forecast.py b/pvsite_datamodel/write/forecast.py index 5de3c66..1ef6af8 100644 --- a/pvsite_datamodel/write/forecast.py +++ b/pvsite_datamodel/write/forecast.py @@ -26,6 +26,8 @@ def insert_forecast_values( :param session: sqlalchemy session for interacting with the database :param forecast_meta: Meta info about the forecast values :param forecast_values_df: dataframe with the data to insert + :param ml_model_name: name of the ML model used to generate the forecast + :param ml_model_version: version of the ML model used to generate the forecast """ forecast = ForecastSQL(**forecast_meta) @@ -36,7 +38,7 @@ def insert_forecast_values( if (ml_model_name is not None) and (ml_model_version is not None): ml_model = get_or_create_model(session, ml_model_name, ml_model_version) - ml_model_uuid = ml_model.uuid + ml_model_uuid = ml_model.model_uuid else: ml_model_uuid = None diff --git a/tests/test_write.py b/tests/test_write.py index ad721e8..fa7d1d5 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -52,7 +52,13 @@ def test_insert_forecast_for_existing_site(self, db_session, forecast_valid_inpu assert ptypes.is_numeric_dtype(forecast_values_df["forecast_power_kw"]) assert ptypes.is_numeric_dtype(forecast_values_df["horizon_minutes"]) - insert_forecast_values(db_session, forecast_meta, forecast_values_df) + insert_forecast_values( + db_session, + forecast_meta, + forecast_values_df, + ml_model_name="test", + ml_model_version="0.0.0", + ) assert db_session.query(ForecastSQL).count() == 1 assert db_session.query(ForecastValueSQL).count() == 10