Skip to content

Commit

Permalink
Fix/hfc opti reg prob (unit8co#2588)
Browse files Browse the repository at this point in the history
* fix: check that model is probabilistic when num samples is greater than 1 for optimized historical forecasts

* feat: update the tests accordingly

* update changelog

* fix: simplify the test

* fix: remove typo

* fix: ignoring a linting commit for git blame
  • Loading branch information
madtoinou authored Nov 11, 2024
1 parent 18e2e3f commit d909589
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
38cc6712a6f701703074a7a7c82ce0252fe869ee
# Fix last isort issues and update Black to 22.1.0
8158d3eaef9d9f6e04f219b029e306d1f1be46d5
# Change Python target-version to 3.9 and update Ruff to 0.7.2
18e2e3fd7d82d239ab24807fcc1033094ea09940
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Fixed**

- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug when using `num_samples > 1` with a deterministic regression model and the optimized `historical_forecasts()` method, an exception was not raised. [#2576](https://github.com/unit8co/darts/pull/2588) by [Antoine Madrona](https://github.com/madtoinou).

**Dependencies**

Expand Down
42 changes: 42 additions & 0 deletions darts/tests/models/forecasting/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,48 @@ def test_historical_forecast(self, mode):
)
assert len(result) == 21

def test_opti_historical_forecast_predict_checks(self):
"""
Verify that the sanity check implemented in ForecastingModel.predict are also defined for optimized historical
forecasts as it does not call this method
"""
model = self.models[1](lags=5)

msg_expected = (
"The model has not been fitted yet, and `retrain` is ``False``. Either call `fit()` before "
"`historical_forecasts()`, or set `retrain` to something different than ``False``."
)
# untrained model, optimized
with pytest.raises(ValueError) as err:
model.historical_forecasts(
series=self.sine_univariate1,
start=0.9,
forecast_horizon=1,
retrain=False,
enable_optimization=True,
verbose=False,
)
assert str(err.value) == msg_expected

model.fit(
series=self.sine_univariate1,
)
# deterministic model, num_samples > 1, optimized
with pytest.raises(ValueError) as err:
model.historical_forecasts(
series=self.sine_univariate1,
start=0.9,
forecast_horizon=1,
retrain=False,
enable_optimization=True,
num_samples=10,
verbose=False,
)
assert (
str(err.value)
== "`num_samples > 1` is only supported for probabilistic models."
)

@pytest.mark.parametrize(
"config",
[
Expand Down
7 changes: 7 additions & 0 deletions darts/utils/historical_forecasts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def _historical_forecasts_general_checks(model, series, kwargs):
logger,
)

# duplication of ForecastingModel.predict() check for the optimized historical forecasts implementations
if not model.supports_probabilistic_prediction and n.num_samples > 1:
raise_log(
ValueError("`num_samples > 1` is only supported for probabilistic models."),
logger,
)

# check direct likelihood parameter prediction before fitting a model
if n.predict_likelihood_parameters:
if not model.supports_likelihood_parameter_prediction:
Expand Down

0 comments on commit d909589

Please sign in to comment.