From da049e56d95c4d73de334d9ff41457c8962fef36 Mon Sep 17 00:00:00 2001 From: madtoinou <32447896+madtoinou@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:51:25 +0100 Subject: [PATCH] Fix/exp smooth constructor args (#2059) * feat: adding support for constructor kwargs * feat: adding tests * fix: udpated representation test for ExponentialSmoothing model * update changelog.md --------- Co-authored-by: dennisbader --- CHANGELOG.md | 1 + .../forecasting/exponential_smoothing.py | 12 ++- .../forecasting/test_exponential_smoothing.py | 79 +++++++++++++++---- .../test_local_forecasting_models.py | 2 +- 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d2327303c..0fa67fd3de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader). - Other improvements: - Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader). + - Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou). **Fixed** - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou). diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index dda34b6992..a847b155d5 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -3,7 +3,7 @@ --------------------- """ -from typing import Optional +from typing import Any, Dict, Optional import numpy as np import statsmodels.tsa.holtwinters as hw @@ -24,7 +24,8 @@ def __init__( seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE, seasonal_periods: Optional[int] = None, random_state: int = 0, - **fit_kwargs, + kwargs: Optional[Dict[str, Any]] = None, + **fit_kwargs ): """Exponential Smoothing @@ -61,6 +62,11 @@ def __init__( seasonal_periods The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily data with a weekly cycle. If not set, inferred from frequency of the series. + kwargs + Some optional keyword arguments that will be used to call + :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`. + See `the documentation + `_. fit_kwargs Some optional keyword arguments that will be used to call :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. @@ -91,6 +97,7 @@ def __init__( self.seasonal = seasonal self.infer_seasonal_periods = seasonal_periods is None self.seasonal_periods = seasonal_periods + self.constructor_kwargs = dict() if kwargs is None else kwargs self.fit_kwargs = fit_kwargs self.model = None np.random.seed(random_state) @@ -120,6 +127,7 @@ def fit(self, series: TimeSeries): seasonal_periods=seasonal_periods_param, freq=series.freq if series.has_datetime_index else None, dates=series.time_index if series.has_datetime_index else None, + **self.constructor_kwargs ) hw_results = hw_model.fit(**self.fit_kwargs) self.model = hw_results diff --git a/darts/tests/models/forecasting/test_exponential_smoothing.py b/darts/tests/models/forecasting/test_exponential_smoothing.py index 173a2ba508..63b494ae44 100644 --- a/darts/tests/models/forecasting/test_exponential_smoothing.py +++ b/darts/tests/models/forecasting/test_exponential_smoothing.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from darts import TimeSeries from darts.models import ExponentialSmoothing @@ -6,36 +7,86 @@ class TestExponentialSmoothing: - def helper_test_seasonality_inference(self, freq_string, expected_seasonal_periods): - series = tg.sine_timeseries(length=200, freq=freq_string) - model = ExponentialSmoothing() - model.fit(series) - assert model.seasonal_periods == expected_seasonal_periods + series = tg.sine_timeseries(length=100, freq="H") - def test_seasonality_inference(self): - - # test `seasonal_periods` inference for datetime indices - freq_str_seasonality_periods_tuples = [ + @pytest.mark.parametrize( + "freq_string,expected_seasonal_periods", + [ ("D", 7), ("H", 24), ("M", 12), ("W", 52), ("Q", 4), ("B", 5), - ] - for tuple in freq_str_seasonality_periods_tuples: - self.helper_test_seasonality_inference(*tuple) + ], + ) + def test_seasonality_inference( + self, freq_string: str, expected_seasonal_periods: int + ): + series = tg.sine_timeseries(length=200, freq=freq_string) + model = ExponentialSmoothing() + model.fit(series) + assert model.seasonal_periods == expected_seasonal_periods - # test default selection for integer index + def test_default_parameters(self): + """Test default selection for integer index""" series = TimeSeries.from_values(np.arange(1, 30, 1)) model = ExponentialSmoothing() model.fit(series) assert model.seasonal_periods == 12 - # test whether a model that inferred a seasonality period before will do it again for a new series + def test_multiple_fit(self): + """Test whether a model that inferred a seasonality period before will do it again for a new series""" series1 = tg.sine_timeseries(length=100, freq="M") series2 = tg.sine_timeseries(length=100, freq="D") model = ExponentialSmoothing() model.fit(series1) model.fit(series2) assert model.seasonal_periods == 7 + + def test_constructor_kwargs(self): + """Using kwargs to pass additional parameters to the constructor""" + constructor_kwargs = { + "initialization_method": "known", + "initial_level": 0.5, + "initial_trend": 0.2, + "initial_seasonal": np.arange(1, 25), + } + model = ExponentialSmoothing(kwargs=constructor_kwargs) + model.fit(self.series) + # must be checked separately, name is not consistent + np.testing.assert_array_almost_equal( + model.model.model.params["initial_seasons"], + constructor_kwargs["initial_seasonal"], + ) + for param_name in ["initial_level", "initial_trend"]: + assert ( + model.model.model.params[param_name] == constructor_kwargs[param_name] + ) + + def test_fit_kwargs(self): + """Using kwargs to pass additional parameters to the fit()""" + # using default optimization method + model = ExponentialSmoothing() + model.fit(self.series) + assert model.fit_kwargs == {} + pred = model.predict(n=2) + + model_bis = ExponentialSmoothing() + model_bis.fit(self.series) + assert model_bis.fit_kwargs == {} + pred_bis = model_bis.predict(n=2) + + # two methods with the same parameters should yield the same forecasts + assert pred.time_index.equals(pred_bis.time_index) + np.testing.assert_array_almost_equal(pred.values(), pred_bis.values()) + + # change optimization method + model_ls = ExponentialSmoothing(method="least_squares") + model_ls.fit(self.series) + assert model_ls.fit_kwargs == {"method": "least_squares"} + pred_ls = model_ls.predict(n=2) + + # forecasts should be slightly different + assert pred.time_index.equals(pred_ls.time_index) + assert all(np.not_equal(pred.values(), pred_ls.values())) diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 557883b4c0..f3ac21d40d 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -651,7 +651,7 @@ def test_model_str_call(self, config): ( ExponentialSmoothing(), "ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, " - + "seasonal_periods=None, random_state=0)", + + "seasonal_periods=None, random_state=0, kwargs=None)", ), # no params changed ( ARIMA(1, 1, 1),