Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Imp/Add new model StatsForecastAutoTBATS #2611

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added `data_transformers` argument to `historical_forecasts`, `backtest`, `residuals`, and `gridsearch` that allow to automatically apply `DataTransformer` and/or `Pipeline` to the input series without data-leakage (fit on historic window of input series, transform the input series, and inverse transform the forecasts). [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) and [Jan Fidor](https://github.com/JanFidor)
- Added `series_idx` argument to `DataTransformer` that allows users to use only a subset of the transformers when `global_fit=False` and severals series are used. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou)
- Updated the Documentation URL of `Statsforecast` models. [#2610](https://github.com/unit8co/darts/pull/2610) by [He Weilin](https://github.com/cnhwl).
- New model: `StatsForecastAutoTBATS`. This model offers the [AutoTBATS](https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats) model from Nixtla's `statsforecasts` library. [#2611](https://github.com/unit8co/darts/pull/2611) by [He Weilin](https://github.com/cnhwl).

**Fixed**

Expand All @@ -29,6 +30,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
### For developers of the library:

**Improved**

- Improvements to CI/CD: [#2584](https://github.com/unit8co/darts/pull/2584) by [Dennis Bader](https://github.com/dennisbader).
- updated all workflows with most recent action versions
- improved caching across `master` branch and its children
Expand Down Expand Up @@ -83,6 +85,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed the comment of `scorers_are_univariate` in class `AnomalyModel`. [#2452](https://github.com/unit8co/darts/pull/2542) by [He Weilin](https://github.com/cnhwl).

**Dependencies**

- Bumped release requirements versions for jupyterlab and dependencies: [#2515](https://github.com/unit8co/darts/pull/2515) by [Dennis Bader](https://github.com/dennisbader).
- Bumped `ipython` from 8.10.0 to 8.18.1
- Bumped `ipykernel` from 5.3.4 to 6.29.5
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ on bringing more models and features.
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | [TBATS paper](https://robjhyndman.com/papers/ComplexSeasonality.pdf) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | [Theta](https://robjhyndman.com/papers/Theta.pdf) & [4 Theta](https://github.com/Mcompetitions/M4-methods/blob/master/4Theta%20method.R) | ✅ 🔴 | 🔴 🔴 🔴 | 🔴 🔴 | 🔴 |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | [Nixtla's statsforecast](https://github.com/Nixtla/statsforecast) | ✅ 🔴 | 🔴 🔴 🔴 | ✅ 🔴 | 🔴 |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | [Prophet repo](https://github.com/facebook/prophet) | ✅ 🔴 | 🔴 ✅ 🔴 | ✅ 🔴 | 🔴 |
Expand Down
3 changes: 3 additions & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES
from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS
from darts.models.forecasting.sf_auto_tbats import StatsForecastAutoTBATS
from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta

except ImportError:
Expand All @@ -102,6 +103,7 @@
StatsForecastAutoCES = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoETS = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTheta = NotImportedModule(module_name="StatsForecast", warn=False)
StatsForecastAutoTBATS = NotImportedModule(module_name="StatsForecast", warn=False)

try:
from darts.models.forecasting.xgboost import XGBModel
Expand Down Expand Up @@ -159,6 +161,7 @@
"StatsForecastAutoCES",
"StatsForecastAutoETS",
"StatsForecastAutoTheta",
"StatsForecastAutoTBATS",
"XGBModel",
"GaussianProcessFilter",
"KalmanFilter",
Expand Down
1 change: 1 addition & 0 deletions darts/models/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- :class:`~darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES`
- :class:`~darts.models.forecasting.tbats_model.BATS`
- :class:`~darts.models.forecasting.tbats_model.TBATS`
- :class:`~darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS`
- :class:`~darts.models.forecasting.theta.Theta`
- :class:`~darts.models.forecasting.theta.FourTheta`
- :class:`~darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta`
Expand Down
104 changes: 104 additions & 0 deletions darts/models/forecasting/sf_auto_tbats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
StatsForecastAutoTBATS
-----------
"""

from statsforecast.models import AutoTBATS as SFAutoTBATS

from darts import TimeSeries
from darts.models.components.statsforecast_utils import (
create_normal_samples,
one_sigma_rule,
unpack_sf_dict,
)
from darts.models.forecasting.forecasting_model import LocalForecastingModel


class StatsForecastAutoTBATS(LocalForecastingModel):
def __init__(self, *autoTBATS_args, **autoTBATS_kwargs):
"""Auto-TBATS based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.

Automatically selects the best TBATS model from all feasible combinations of the parameters `use_boxcox`,
`use_trend`, `use_damped_trend`, and `use_arma_errors`. Selection is made using the AIC.
Default value for `use_arma_errors` is True since this enables the evaluation of models with
and without ARMA errors.
<https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=f3de25596ab60ef0e886366826bf58a02b35a44f>
<https://doi.org/10.4225/03/589299681de3d>

We refer to the `statsforecast AutoTBATS documentation
<https://nixtlaverse.nixtla.io/statsforecast/src/core/models.html#autotbats>`_
for the exhaustive documentation of the arguments.

Parameters
----------
autoTBATS_args
Positional arguments for ``statsforecasts.models.AutoTBATS``.
autoTBATS_kwargs
Keyword arguments for ``statsforecasts.models.AutoTBATS``.

Examples
--------
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastAutoTBATS
>>> series = AirPassengersDataset().load()
>>> # define StatsForecastAutoTBATS parameters
>>> model = StatsForecastAutoTBATS(season_length=12)
>>> model.fit(series)
>>> pred = model.predict(6)
>>> pred.values()
array([[450.79653684],
[472.09265790],
[497.76948306],
[510.74927369],
[520.92224557],
[570.33881522]])
"""
super().__init__()
self.model = SFAutoTBATS(*autoTBATS_args, **autoTBATS_kwargs)

def fit(self, series: TimeSeries):
super().fit(series)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
)
return self

def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
show_warnings: bool = True,
):
super().predict(n, num_samples)
forecast_dict = self.model.predict(
h=n,
level=(one_sigma_rule,), # ask one std for the confidence interval.
)

mu, std = unpack_sf_dict(forecast_dict)
if num_samples > 1:
samples = create_normal_samples(mu, std, num_samples, n)
else:
samples = mu

return self._build_forecast_series(samples)

@property
def supports_multivariate(self) -> bool:
return False

@property
def min_train_series_length(self) -> int:
return 10

@property
def _supports_range_index(self) -> bool:
return True

@property
def supports_probabilistic_prediction(self) -> bool:
return True
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
StatsForecastAutoARIMA,
StatsForecastAutoCES,
StatsForecastAutoETS,
StatsForecastAutoTBATS,
StatsForecastAutoTheta,
Theta,
)
Expand All @@ -57,6 +58,7 @@
(StatsForecastAutoTheta(season_length=12), 5.5),
(StatsForecastAutoCES(season_length=12, model="Z"), 7.3),
(StatsForecastAutoETS(season_length=12, model="AAZ"), 7.3),
(StatsForecastAutoTBATS(season_length=12), 10),
(Croston(version="classic"), 23),
(Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 23),
(Theta(), 11),
Expand Down
13 changes: 13 additions & 0 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
LightGBMModel,
LinearRegressionModel,
NotImportedModule,
StatsForecastAutoTBATS,
XGBModel,
)
from darts.tests.conftest import TORCH_AVAILABLE, tfm_kwargs
Expand Down Expand Up @@ -92,6 +93,18 @@
0.04,
0.04,
),
(
StatsForecastAutoTBATS,
{
"season_length": 1,
"use_trend": False,
"use_damped_trend": False,
"use_boxcox": True,
"use_arma_errors": False,
},
0.04,
0.04,
),
]

xgb_test_params = {
Expand Down
1 change: 1 addition & 0 deletions docs/userguide/covariates.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ GFMs are models that can be trained on multiple target (and covariate) time seri
| [StatsforecastAutoETS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ets.html#darts.models.forecasting.sf_auto_ets.StatsForecastAutoETS) | | ✅ | |
| [StatsforecastAutoCES](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_ces.html#darts.models.forecasting.sf_auto_ces.StatsForecastAutoCES) | | | |
| [BATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.BATS) and [TBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.tbats_model.html#darts.models.forecasting.tbats_model.TBATS) | | | |
| [StatsForecastAutoTBATS](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_tbats.html#darts.models.forecasting.sf_auto_tbats.StatsForecastAutoTBATS) | | | |
| [Theta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.Theta) and [FourTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.theta.html#darts.models.forecasting.theta.FourTheta) | | | |
| [StatsForecastAutoTheta](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.sf_auto_theta.html#darts.models.forecasting.sf_auto_theta.StatsForecastAutoTheta) | | | |
| [Prophet](https://unit8co.github.io/darts/generated_api/darts.models.forecasting.prophet_model.html#darts.models.forecasting.prophet_model.Prophet) | | ✅ | |
Expand Down
Loading