From 9f8cee9cb0b8f0ffa4517c2af5d69ace81f9e5a8 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 16 Nov 2023 10:37:02 +0100 Subject: [PATCH] uddate fit/predict wrappers --- darts/models/forecasting/forecasting_model.py | 59 ++++++++----------- .../models/forecasting/test_backtesting.py | 8 +-- .../forecasting/test_historical_forecasts.py | 8 +-- 3 files changed, 33 insertions(+), 42 deletions(-) diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index b245aa10b3..30d2c183d6 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -321,18 +321,17 @@ def _fit_wrapper( ): supported_params = inspect.signature(self.fit).parameters kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params} - if self.supports_past_covariates: - kwargs_["past_covariates"] = past_covariates - elif past_covariates is not None: - raise_log( - ValueError("Model cannot be fitted with `past_covariates`."), logger - ) - if self.supports_future_covariates: - kwargs_["future_covariates"] = future_covariates - elif future_covariates is not None: - raise_log( - ValueError("Model cannot be fitted with `future_covariates`."), logger - ) + + # handle past and future covariates based on model support + for covs, name in zip([past_covariates, future_covariates], ["past", "future"]): + covs_name = f"{name}_covariates" + if getattr(self, f"supports_{covs_name}"): + kwargs_[covs_name] = covs + elif covs is not None: + raise_log( + ValueError(f"Model cannot be fit/trained with `{covs_name}`."), + logger, + ) self.fit(series, **kwargs_) def _predict_wrapper( @@ -341,29 +340,21 @@ def _predict_wrapper( **kwargs, ) -> Union[TimeSeries, Sequence[TimeSeries]]: supported_params = set(inspect.signature(self.predict).parameters) + # if predict() accepts covariates, the model might not support them at inference - if "past_covariates" in kwargs and not self.supports_past_covariates: - if kwargs["past_covariates"] is None: - supported_params = supported_params - {"past_covariates"} - else: - raise_log( - ValueError( - "Model does not support `past_covariates` at inference, either remove this covariate or " - "fit the model with this covariate." - ), - logger, - ) - if "future_covariates" in kwargs and not self.supports_future_covariates: - if kwargs["future_covariates"] is None: - supported_params = supported_params - {"future_covariates"} - else: - raise_log( - ValueError( - "Model does not support `future_covariates` at inference, either remove this covariate or " - "fit the model with this covariate." - ), - logger, - ) + for covs_name in ["past_covariates", "future_covariates"]: + if covs_name in kwargs and not getattr(self, f"supports_{covs_name}"): + if kwargs[covs_name] is None: + supported_params = supported_params - {covs_name} + else: + raise_log( + ValueError( + f"Model prediction does not support `{covs_name}`, either because it " + f"does not support `{covs_name}` in general, or because it was fit/trained " + f"without using `{covs_name}`." + ), + logger, + ) kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params} return self.predict(n, **kwargs_) diff --git a/darts/tests/models/forecasting/test_backtesting.py b/darts/tests/models/forecasting/test_backtesting.py index 49de3305da..e54ca70d5d 100644 --- a/darts/tests/models/forecasting/test_backtesting.py +++ b/darts/tests/models/forecasting/test_backtesting.py @@ -476,13 +476,13 @@ def test_backtest_bad_covariates(self, model_cls): with pytest.raises(ValueError) as msg: model.backtest(series=series, past_covariates=series, **bt_kwargs) assert str(msg.value).startswith( - "Model cannot be fitted with `past_covariates`." + "Model cannot be fit/trained with `past_covariates`." ) if not model.supports_future_covariates: with pytest.raises(ValueError) as msg: model.backtest(series=series, future_covariates=series, **bt_kwargs) assert str(msg.value).startswith( - "Model cannot be fitted with `future_covariates`." + "Model cannot be fit/trained with `future_covariates`." ) def test_gridsearch(self): @@ -679,7 +679,7 @@ def test_gridsearch_bad_covariates(self, model_cls, parameters): **bt_kwargs ) assert str(msg.value).startswith( - "Model cannot be fitted with `past_covariates`." + "Model cannot be fit/trained with `past_covariates`." ) if not model.supports_future_covariates: with pytest.raises(ValueError) as msg: @@ -691,5 +691,5 @@ def test_gridsearch_bad_covariates(self, model_cls, parameters): **bt_kwargs ) assert str(msg.value).startswith( - "Model cannot be fitted with `future_covariates`." + "Model cannot be fit/trained with `future_covariates`." ) diff --git a/darts/tests/models/forecasting/test_historical_forecasts.py b/darts/tests/models/forecasting/test_historical_forecasts.py index 85bcf96baa..77907f3f1a 100644 --- a/darts/tests/models/forecasting/test_historical_forecasts.py +++ b/darts/tests/models/forecasting/test_historical_forecasts.py @@ -388,7 +388,7 @@ def test_historical_forecasts_transferrable_future_cov_local_models(self): retrain=False, ) assert str(msg.value).startswith( - "Model does not support `past_covariates` at inference" + "Model prediction does not support `past_covariates`" ) def test_historical_forecasts_future_cov_local_models(self): @@ -421,7 +421,7 @@ def test_historical_forecasts_future_cov_local_models(self): retrain=True, ) assert str(msg.value).startswith( - "Model cannot be fitted with `past_covariates`." + "Model cannot be fit/trained with `past_covariates`." ) def test_historical_forecasts_local_models(self): @@ -655,7 +655,7 @@ def test_historical_forecasts(self, config): retrain=True, ) assert str(msg.value).startswith( - "Model cannot be fitted with `past_covariates`." + "Model cannot be fit/trained with `past_covariates`." ) if not model.supports_future_covariates: @@ -666,7 +666,7 @@ def test_historical_forecasts(self, config): last_points_only=False, ) assert str(msg.value).startswith( - "Model cannot be fitted with `future_covariates`." + "Model cannot be fit/trained with `future_covariates`." ) def test_sanity_check_invalid_start(self):