Skip to content

Commit

Permalink
uddate fit/predict wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Nov 16, 2023
1 parent a5e21c3 commit 9f8cee9
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 42 deletions.
59 changes: 25 additions & 34 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,18 +321,17 @@ def _fit_wrapper(
):
supported_params = inspect.signature(self.fit).parameters
kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}
if self.supports_past_covariates:
kwargs_["past_covariates"] = past_covariates
elif past_covariates is not None:
raise_log(
ValueError("Model cannot be fitted with `past_covariates`."), logger
)
if self.supports_future_covariates:
kwargs_["future_covariates"] = future_covariates
elif future_covariates is not None:
raise_log(
ValueError("Model cannot be fitted with `future_covariates`."), logger
)

# handle past and future covariates based on model support
for covs, name in zip([past_covariates, future_covariates], ["past", "future"]):
covs_name = f"{name}_covariates"
if getattr(self, f"supports_{covs_name}"):
kwargs_[covs_name] = covs
elif covs is not None:
raise_log(
ValueError(f"Model cannot be fit/trained with `{covs_name}`."),
logger,
)
self.fit(series, **kwargs_)

def _predict_wrapper(
Expand All @@ -341,29 +340,21 @@ def _predict_wrapper(
**kwargs,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
supported_params = set(inspect.signature(self.predict).parameters)

# if predict() accepts covariates, the model might not support them at inference
if "past_covariates" in kwargs and not self.supports_past_covariates:
if kwargs["past_covariates"] is None:
supported_params = supported_params - {"past_covariates"}
else:
raise_log(
ValueError(
"Model does not support `past_covariates` at inference, either remove this covariate or "
"fit the model with this covariate."
),
logger,
)
if "future_covariates" in kwargs and not self.supports_future_covariates:
if kwargs["future_covariates"] is None:
supported_params = supported_params - {"future_covariates"}
else:
raise_log(
ValueError(
"Model does not support `future_covariates` at inference, either remove this covariate or "
"fit the model with this covariate."
),
logger,
)
for covs_name in ["past_covariates", "future_covariates"]:
if covs_name in kwargs and not getattr(self, f"supports_{covs_name}"):
if kwargs[covs_name] is None:
supported_params = supported_params - {covs_name}
else:
raise_log(
ValueError(
f"Model prediction does not support `{covs_name}`, either because it "
f"does not support `{covs_name}` in general, or because it was fit/trained "
f"without using `{covs_name}`."
),
logger,
)

kwargs_ = {k: v for k, v in kwargs.items() if k in supported_params}
return self.predict(n, **kwargs_)
Expand Down
8 changes: 4 additions & 4 deletions darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,13 @@ def test_backtest_bad_covariates(self, model_cls):
with pytest.raises(ValueError) as msg:
model.backtest(series=series, past_covariates=series, **bt_kwargs)
assert str(msg.value).startswith(
"Model cannot be fitted with `past_covariates`."
"Model cannot be fit/trained with `past_covariates`."
)
if not model.supports_future_covariates:
with pytest.raises(ValueError) as msg:
model.backtest(series=series, future_covariates=series, **bt_kwargs)
assert str(msg.value).startswith(
"Model cannot be fitted with `future_covariates`."
"Model cannot be fit/trained with `future_covariates`."
)

def test_gridsearch(self):
Expand Down Expand Up @@ -679,7 +679,7 @@ def test_gridsearch_bad_covariates(self, model_cls, parameters):
**bt_kwargs
)
assert str(msg.value).startswith(
"Model cannot be fitted with `past_covariates`."
"Model cannot be fit/trained with `past_covariates`."
)
if not model.supports_future_covariates:
with pytest.raises(ValueError) as msg:
Expand All @@ -691,5 +691,5 @@ def test_gridsearch_bad_covariates(self, model_cls, parameters):
**bt_kwargs
)
assert str(msg.value).startswith(
"Model cannot be fitted with `future_covariates`."
"Model cannot be fit/trained with `future_covariates`."
)
8 changes: 4 additions & 4 deletions darts/tests/models/forecasting/test_historical_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_historical_forecasts_transferrable_future_cov_local_models(self):
retrain=False,
)
assert str(msg.value).startswith(
"Model does not support `past_covariates` at inference"
"Model prediction does not support `past_covariates`"
)

def test_historical_forecasts_future_cov_local_models(self):
Expand Down Expand Up @@ -421,7 +421,7 @@ def test_historical_forecasts_future_cov_local_models(self):
retrain=True,
)
assert str(msg.value).startswith(
"Model cannot be fitted with `past_covariates`."
"Model cannot be fit/trained with `past_covariates`."
)

def test_historical_forecasts_local_models(self):
Expand Down Expand Up @@ -655,7 +655,7 @@ def test_historical_forecasts(self, config):
retrain=True,
)
assert str(msg.value).startswith(
"Model cannot be fitted with `past_covariates`."
"Model cannot be fit/trained with `past_covariates`."
)

if not model.supports_future_covariates:
Expand All @@ -666,7 +666,7 @@ def test_historical_forecasts(self, config):
last_points_only=False,
)
assert str(msg.value).startswith(
"Model cannot be fitted with `future_covariates`."
"Model cannot be fit/trained with `future_covariates`."
)

def test_sanity_check_invalid_start(self):
Expand Down

0 comments on commit 9f8cee9

Please sign in to comment.