From 772d705c9d69b8cce58756e306683adce96361b4 Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Mon, 6 Nov 2023 11:06:52 +0100 Subject: [PATCH 1/4] Feat/tz aware dta (#2054) * add tests for datetime_attribute_timeseries and holidays * add tz to encoders * add tests * add timezone to some timeseries methods * update changelog and model docs * apply suggestions from PR review --- CHANGELOG.md | 2 + darts/dataprocessing/encoders/encoders.py | 80 ++++++++++++- darts/models/forecasting/arima.py | 3 +- darts/models/forecasting/auto_arima.py | 3 +- darts/models/forecasting/block_rnn_model.py | 3 +- darts/models/forecasting/catboost_model.py | 3 +- darts/models/forecasting/croston.py | 3 +- darts/models/forecasting/dlinear.py | 3 +- darts/models/forecasting/kalman_forecaster.py | 3 +- darts/models/forecasting/lgbm.py | 3 +- .../forecasting/linear_regression_model.py | 3 +- darts/models/forecasting/nbeats.py | 3 +- darts/models/forecasting/nhits.py | 3 +- darts/models/forecasting/nlinear.py | 3 +- darts/models/forecasting/prophet_model.py | 3 +- darts/models/forecasting/random_forest.py | 3 +- darts/models/forecasting/regression_model.py | 6 +- darts/models/forecasting/rnn_model.py | 3 +- darts/models/forecasting/sf_auto_arima.py | 3 +- darts/models/forecasting/sf_auto_ets.py | 3 +- darts/models/forecasting/tcn_model.py | 3 +- darts/models/forecasting/tft_model.py | 3 +- darts/models/forecasting/tide_model.py | 3 +- .../forecasting/torch_forecasting_model.py | 3 +- darts/models/forecasting/transformer_model.py | 3 +- darts/models/forecasting/varima.py | 3 +- darts/models/forecasting/xgboost.py | 3 +- .../dataprocessing/encoders/test_encoders.py | 39 ++++++ .../tests/utils/test_timeseries_generation.py | 113 ++++++++++++++++++ darts/timeseries.py | 30 ++++- darts/utils/timeseries_generation.py | 82 ++++++++++--- 31 files changed, 374 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11bc0b42d6..4a68470899 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat). - Improvements to Regression Models: - `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader). +- Other improvements: + - Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader). **Fixed** - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou). diff --git a/darts/dataprocessing/encoders/encoders.py b/darts/dataprocessing/encoders/encoders.py index f590724710..09b3414593 100644 --- a/darts/dataprocessing/encoders/encoders.py +++ b/darts/dataprocessing/encoders/encoders.py @@ -28,6 +28,7 @@ input_chunk_length=24, output_chunk_length=12, attribute='month' + tz='CET' ) past_covariates_train = encoder.encode_train( @@ -75,6 +76,8 @@ attribute An attribute of `pd.DatetimeIndex`: see all available attributes in https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DatetimeIndex.html#pandas.DatetimeIndex. + tz + Optionally, convert the time zone naive index to a time zone `tz` before applying the encoder. * `CyclicTemporalEncoder` Adds cyclic pd.DatetimeIndex attribute information deriveed from `series.time_index`. Adds 2 columns, corresponding to sin and cos encodings, to uniquely describe the underlying attribute. @@ -84,6 +87,8 @@ An attribute of `pd.DatetimeIndex` that follows a cyclic pattern. One of ('month', 'day', 'weekday', 'dayofweek', 'day_of_week', 'hour', 'minute', 'second', 'microsecond', 'nanosecond', 'quarter', 'dayofyear', 'day_of_year', 'week', 'weekofyear', 'week_of_year'). + tz + Optionally, convert the time zone naive index to a time zone `tz` before applying the encoder. * `IntegerIndexEncoder` Adds the relative index positions as integer values (positions) derived from `series` time index. `series` can either have a pd.DatetimeIndex or an integer index. @@ -121,6 +126,7 @@ * 'position' for `IntegerIndexEncoder` * 'custom' for `CallableIndexEncoder` * 'transformer' for a transformer + * 'tz' for applying a time zone conversion * inner keys: covariates type * 'past' for past covariates @@ -142,7 +148,8 @@ 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [lambda idx: (idx.year - 1950) / 50]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET', } model = SomeTorchForecastingModel(..., add_encoders=add_encoders) @@ -184,6 +191,7 @@ VALID_TIME_PARAMS = [FUTURE, PAST] VALID_ENCODER_DTYPES = (str, Sequence) +TZ_KEYS = ["tz"] TRANSFORMER_KEYS = ["transformer"] VALID_TRANSFORMER_DTYPES = FittableDataTransformer INTEGER_INDEX_ATTRIBUTES = ["relative"] @@ -192,7 +200,12 @@ class CyclicTemporalEncoder(SingleEncoder): """`CyclicTemporalEncoder`: Cyclic encoding of time series datetime attributes.""" - def __init__(self, index_generator: CovariatesIndexGenerator, attribute: str): + def __init__( + self, + index_generator: CovariatesIndexGenerator, + attribute: str, + tz: Optional[str] = None, + ): """ Cyclic index encoding for `TimeSeries` that have a time index of type `pd.DatetimeIndex`. @@ -208,9 +221,12 @@ def __init__(self, index_generator: CovariatesIndexGenerator, attribute: str): https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DatetimeIndex.html#pandas.DatetimeIndex. For more information, check out :meth:`datetime_attribute_timeseries() ` + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__(index_generator) self.attribute = attribute + self.tz = tz def _encode( self, index: SupportedIndex, target_end: pd.Timestamp, dtype: np.dtype @@ -226,6 +242,7 @@ def _encode( self.base_component_name + self.attribute + "_sin", self.base_component_name + self.attribute + "_cos", ], + tz=self.tz, ) @property @@ -255,6 +272,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + tz: Optional[str] = None, ): """ Parameters @@ -280,6 +298,8 @@ def __init__( Optionally, a list of integers representing the past covariate lags. Accepts integer lag values <= -1. Only required for :class:`RegressionModel`. Corresponds to the lag values from parameter `lags_past_covariates` of :class:`RegressionModel`. + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__( index_generator=PastCovariatesIndexGenerator( @@ -288,6 +308,7 @@ def __init__( lags_covariates=lags_covariates, ), attribute=attribute, + tz=tz, ) @@ -300,6 +321,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + tz: Optional[str] = None, ): """ Parameters @@ -325,6 +347,8 @@ def __init__( Optionally, a list of integers representing the future covariate lags. Accepts all integer values. Only required for :class:`RegressionModel`. Corresponds to the lag values from parameter `lags_future_covariates` from :class:`RegressionModel`. + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__( index_generator=FutureCovariatesIndexGenerator( @@ -333,6 +357,7 @@ def __init__( lags_covariates=lags_covariates, ), attribute=attribute, + tz=tz, ) @@ -341,7 +366,12 @@ class DatetimeAttributeEncoder(SingleEncoder): Requires the underlying TimeSeries to have a pd.DatetimeIndex """ - def __init__(self, index_generator: CovariatesIndexGenerator, attribute: str): + def __init__( + self, + index_generator: CovariatesIndexGenerator, + attribute: str, + tz: Optional[str] = None, + ): """ Parameters ---------- @@ -355,9 +385,12 @@ def __init__(self, index_generator: CovariatesIndexGenerator, attribute: str): https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DatetimeIndex.html#pandas.DatetimeIndex. For more information, check out :meth:`datetime_attribute_timeseries() ` + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__(index_generator) self.attribute = attribute + self.tz = tz def _encode( self, index: SupportedIndex, target_end: pd.Timestamp, dtype: np.dtype @@ -369,6 +402,7 @@ def _encode( attribute=self.attribute, dtype=dtype, with_columns=self.base_component_name + self.attribute, + tz=self.tz, ) @property @@ -398,6 +432,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + tz: Optional[str] = None, ): """ Parameters @@ -423,6 +458,8 @@ def __init__( Optionally, a list of integers representing the past covariate lags. Accepts integer lag values <= -1. Only required for :class:`RegressionModel`. Corresponds to the lag values from parameter `lags_past_covariates` of :class:`RegressionModel`. + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__( index_generator=PastCovariatesIndexGenerator( @@ -431,6 +468,7 @@ def __init__( lags_covariates=lags_covariates, ), attribute=attribute, + tz=tz, ) @@ -443,6 +481,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + tz: Optional[str] = None, ): """ Parameters @@ -468,6 +507,8 @@ def __init__( Optionally, a list of integers representing the future covariate lags. Accepts all integer values. Only required for :class:`RegressionModel`. Corresponds to the lag values from parameter `lags_future_covariates` from :class:`RegressionModel`. + tz + Optionally, a time zone to convert the time index to before computing the attributes. """ super().__init__( index_generator=FutureCovariatesIndexGenerator( @@ -476,6 +517,7 @@ def __init__( lags_covariates=lags_covariates, ), attribute=attribute, + tz=tz, ) @@ -567,6 +609,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + **kwargs, ): """ Parameters @@ -610,6 +653,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + **kwargs, ): """ Parameters @@ -713,6 +757,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + **kwargs, ): """ Parameters @@ -759,6 +804,7 @@ def __init__( input_chunk_length: Optional[int] = None, output_chunk_length: Optional[int] = None, lags_covariates: Optional[List[int]] = None, + **kwargs, ): """ Parameters @@ -837,6 +883,9 @@ def __init__( ` such as Scaler() or BoxCox(). The transformers will be fitted on the training dataset when calling calling `model.fit()`. The training, validation and inference datasets are then transformed equally. + Supported time zone: + Optionally, apply a time zone conversion with keyword 'tz'. This converts the time zone-naive index to a + timezone `'tz'` before applying the `'cyclic'` or `'datetime_attribute'` temporal encoders. An example of a valid `add_encoders` dict for hourly data: @@ -849,7 +898,8 @@ def __init__( 'datetime_attribute': {'past': ['hour'], 'future': ['year', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [lambda idx: (idx.year - 1950) / 50]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET', } Tuples of `(encoder_id, attribute)` are extracted from `add_encoders` to instantiate the `SingleEncoder` @@ -1289,6 +1339,7 @@ def _setup_encoders(self, params: Dict) -> None: * params={'cyclic': {'past': ['month', 'dayofweek', ...], 'future': [same as for 'past']}} """ past_encoders, future_encoders = self._process_input_encoders(params) + tz = self._process_timezone(params) if not past_encoders and not future_encoders: return @@ -1299,6 +1350,7 @@ def _setup_encoders(self, params: Dict) -> None: input_chunk_length=self.input_chunk_length, output_chunk_length=self.output_chunk_length, lags_covariates=self.lags_past_covariates, + tz=tz, ) for enc_id, attr in past_encoders ] @@ -1308,6 +1360,7 @@ def _setup_encoders(self, params: Dict) -> None: input_chunk_length=self.input_chunk_length, output_chunk_length=self.output_chunk_length, lags_covariates=self.lags_future_covariates, + tz=tz, ) for enc_id, attr in future_encoders ] @@ -1369,7 +1422,9 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]: # check input for invalid encoder types invalid_encoders = [ - enc for enc in params if enc not in ENCODER_KEYS + TRANSFORMER_KEYS + enc + for enc in params + if enc not in ENCODER_KEYS + TZ_KEYS + TRANSFORMER_KEYS ] raise_if( len(invalid_encoders) > 0, @@ -1480,6 +1535,21 @@ def _process_input_transformer( ] return transformer, transform_past_mask, transform_future_mask + @staticmethod + def _process_timezone(params: Dict) -> Optional[str]: + """Processes input params used at model creation for time zone specification, and returns the time zone. + + Parameters + ---------- + params + Dict from parameter `add_encoders` (kwargs) used at model creation. Relevant parameters are: + * params={'tz': 'CET'} + """ + if not params: + return None + + return params.get(TZ_KEYS[0], None) + @property def requires_fit(self) -> bool: return any( diff --git a/darts/models/forecasting/arima.py b/darts/models/forecasting/arima.py index dbca628ca3..4891b33719 100644 --- a/darts/models/forecasting/arima.py +++ b/darts/models/forecasting/arima.py @@ -79,7 +79,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. diff --git a/darts/models/forecasting/auto_arima.py b/darts/models/forecasting/auto_arima.py index 4a9e7b6db2..a4f30e01d4 100644 --- a/darts/models/forecasting/auto_arima.py +++ b/darts/models/forecasting/auto_arima.py @@ -62,7 +62,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py index c5eff40a83..5903adf9aa 100644 --- a/darts/models/forecasting/block_rnn_model.py +++ b/darts/models/forecasting/block_rnn_model.py @@ -249,7 +249,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/catboost_model.py b/darts/models/forecasting/catboost_model.py index b25b983a77..45ad8a1984 100644 --- a/darts/models/forecasting/catboost_model.py +++ b/darts/models/forecasting/catboost_model.py @@ -74,7 +74,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. likelihood diff --git a/darts/models/forecasting/croston.py b/darts/models/forecasting/croston.py index c5b6482f62..d71aaf2b29 100644 --- a/darts/models/forecasting/croston.py +++ b/darts/models/forecasting/croston.py @@ -66,7 +66,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. diff --git a/darts/models/forecasting/dlinear.py b/darts/models/forecasting/dlinear.py index 7b78c1a16b..62f41ee621 100644 --- a/darts/models/forecasting/dlinear.py +++ b/darts/models/forecasting/dlinear.py @@ -350,7 +350,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/kalman_forecaster.py b/darts/models/forecasting/kalman_forecaster.py index 7d1cc7ef93..00809a9fc0 100644 --- a/darts/models/forecasting/kalman_forecaster.py +++ b/darts/models/forecasting/kalman_forecaster.py @@ -75,7 +75,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. diff --git a/darts/models/forecasting/lgbm.py b/darts/models/forecasting/lgbm.py index f5ca44e288..c93927fb03 100644 --- a/darts/models/forecasting/lgbm.py +++ b/darts/models/forecasting/lgbm.py @@ -101,7 +101,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. likelihood diff --git a/darts/models/forecasting/linear_regression_model.py b/darts/models/forecasting/linear_regression_model.py index ed6ac9ba25..03773cca82 100644 --- a/darts/models/forecasting/linear_regression_model.py +++ b/darts/models/forecasting/linear_regression_model.py @@ -94,7 +94,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. likelihood diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py index e5c06c3af7..24e789a201 100644 --- a/darts/models/forecasting/nbeats.py +++ b/darts/models/forecasting/nbeats.py @@ -671,7 +671,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py index 7b63cd3ca6..69ba11eee3 100644 --- a/darts/models/forecasting/nhits.py +++ b/darts/models/forecasting/nhits.py @@ -607,7 +607,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/nlinear.py b/darts/models/forecasting/nlinear.py index 51fdb8a359..09e112f6f9 100644 --- a/darts/models/forecasting/nlinear.py +++ b/darts/models/forecasting/nlinear.py @@ -300,7 +300,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index 295e8286d3..cb0abaab4a 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -106,7 +106,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. cap diff --git a/darts/models/forecasting/random_forest.py b/darts/models/forecasting/random_forest.py index a5d91448ed..ad481ffc29 100644 --- a/darts/models/forecasting/random_forest.py +++ b/darts/models/forecasting/random_forest.py @@ -98,7 +98,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. n_estimators : int diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index 68a7b08d38..3f4d77f33c 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -137,7 +137,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. model @@ -1519,7 +1520,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. model diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py index f26cd4a90b..22a4f25cec 100644 --- a/darts/models/forecasting/rnn_model.py +++ b/darts/models/forecasting/rnn_model.py @@ -333,7 +333,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/sf_auto_arima.py b/darts/models/forecasting/sf_auto_arima.py index 91cc12c0c0..c036a80b80 100644 --- a/darts/models/forecasting/sf_auto_arima.py +++ b/darts/models/forecasting/sf_auto_arima.py @@ -59,7 +59,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. autoarima_kwargs diff --git a/darts/models/forecasting/sf_auto_ets.py b/darts/models/forecasting/sf_auto_ets.py index d5971fa813..9636436e0a 100644 --- a/darts/models/forecasting/sf_auto_ets.py +++ b/darts/models/forecasting/sf_auto_ets.py @@ -64,7 +64,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. autoets_kwargs diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index 3b9795b033..076fd939df 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -372,7 +372,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py index af89d05e8f..f77255c601 100644 --- a/darts/models/forecasting/tft_model.py +++ b/darts/models/forecasting/tft_model.py @@ -819,7 +819,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/tide_model.py b/darts/models/forecasting/tide_model.py index 6b506896b1..57704e21dd 100644 --- a/darts/models/forecasting/tide_model.py +++ b/darts/models/forecasting/tide_model.py @@ -498,7 +498,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 7cfb9cd5cb..8312808689 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -223,7 +223,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py index 5287bcfc71..6430b1d108 100644 --- a/darts/models/forecasting/transformer_model.py +++ b/darts/models/forecasting/transformer_model.py @@ -460,7 +460,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. random_state diff --git a/darts/models/forecasting/varima.py b/darts/models/forecasting/varima.py index ece88dba80..7e49df4fa7 100644 --- a/darts/models/forecasting/varima.py +++ b/darts/models/forecasting/varima.py @@ -71,7 +71,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'future': ['relative']}, 'custom': {'future': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. diff --git a/darts/models/forecasting/xgboost.py b/darts/models/forecasting/xgboost.py index 3b694c502f..417b00413f 100644 --- a/darts/models/forecasting/xgboost.py +++ b/darts/models/forecasting/xgboost.py @@ -120,7 +120,8 @@ def encode_year(idx): 'datetime_attribute': {'future': ['hour', 'dayofweek']}, 'position': {'past': ['relative'], 'future': ['relative']}, 'custom': {'past': [encode_year]}, - 'transformer': Scaler() + 'transformer': Scaler(), + 'tz': 'CET' } .. likelihood diff --git a/darts/tests/dataprocessing/encoders/test_encoders.py b/darts/tests/dataprocessing/encoders/test_encoders.py index dd3782dde1..41bc8c29dc 100644 --- a/darts/tests/dataprocessing/encoders/test_encoders.py +++ b/darts/tests/dataprocessing/encoders/test_encoders.py @@ -490,6 +490,7 @@ def extract_year(index): "future": [extract_month, extract_year], }, "transformer": Scaler(), + "tz": "CET", } # given `add_encoders` dict, we expect encoders to generate the following components comps_expected_past = pd.Index( @@ -1339,3 +1340,41 @@ def helper_test_encoder_single_inference( assert encoded == result for enc, enc_train_inf in zip(encoded, encoded_train_inf): assert enc == enc_train_inf[enc.time_index] + + def test_tz_conversion(self): + add_encoders = { + "cyclic": {"past": "hour", "future": "hour"}, + "datetime_attribute": {"past": "hour", "future": "hour"}, + } + encs = SequentialEncoder( + add_encoders=add_encoders, + input_chunk_length=12, + output_chunk_length=6, + takes_past_covariates=True, + takes_future_covariates=True, + ) + # convert to Central European Time (CET) + encs_tz = SequentialEncoder( + add_encoders=dict({"tz": "CET"}, **add_encoders), + input_chunk_length=12, + output_chunk_length=6, + takes_past_covariates=True, + takes_future_covariates=True, + ) + + ts = tg.linear_timeseries( + start=pd.Timestamp("2000-01-01 00:00:00"), length=48, freq="h" + ) + pc1, fc1 = encs.encode_train(ts) + pc2, fc2 = encs.encode_inference(n=6, target=ts) + + pc1_tz, fc1_tz = encs_tz.encode_train(ts) + pc2_tz, fc2_tz = encs_tz.encode_inference(n=6, target=ts) + + for vals, vals_tz in zip( + [pc1, pc2, fc1, fc2], [pc1_tz, pc2_tz, fc1_tz, fc2_tz] + ): + # CET is +1 hour compared to UTC, so we shift by 1 + np.testing.assert_array_almost_equal( + np.roll(vals.values(), -1, axis=0)[:-1], vals_tz.values()[:-1] + ) diff --git a/darts/tests/utils/test_timeseries_generation.py b/darts/tests/utils/test_timeseries_generation.py index 5342879dcd..606e36d311 100644 --- a/darts/tests/utils/test_timeseries_generation.py +++ b/darts/tests/utils/test_timeseries_generation.py @@ -4,9 +4,11 @@ import pandas as pd import pytest +from darts import TimeSeries from darts.utils.timeseries_generation import ( autoregressive_timeseries, constant_timeseries, + datetime_attribute_timeseries, gaussian_timeseries, generate_index, holidays_timeseries, @@ -183,6 +185,27 @@ def test_routine( with pytest.raises(ValueError): holidays_timeseries(time_index_3, "US", until=163) + # test non time zone-naive + with pytest.raises(ValueError): + holidays_timeseries(time_index_3.tz_localize("UTC"), "US", until=163) + + # test holiday with and without time zone, 1st of August is national holiday in Switzerland + # time zone naive (e.g. in UTC) + idx = generate_index( + start=pd.Timestamp("2000-07-31 22:00:00"), length=3, freq="h" + ) + ts = holidays_timeseries(idx, country_code="CH") + np.testing.assert_array_almost_equal(ts.values()[:, 0], np.array([0, 0, 1])) + + # time zone CET (+2 hour compared to UTC) + ts = holidays_timeseries(idx, country_code="CH", tz="CET") + np.testing.assert_array_almost_equal(ts.values()[:, 0], np.array([1, 1, 1])) + + # check same from TimeSeries + series = TimeSeries.from_times_and_values(times=idx, values=np.arange(len(idx))) + ts = holidays_timeseries(series, country_code="CH", tz="CET") + np.testing.assert_array_almost_equal(ts.values()[:, 0], np.array([1, 1, 1])) + def test_generate_index(self): def test_routine( expected_length, @@ -321,3 +344,93 @@ def test_calculation(coef): for coef_assert in [[-1], [-1, 1.618], [1, 2, 3], list(range(10))]: test_calculation(coef=coef_assert) + + def test_datetime_attribute_timeseries(self): + idx = generate_index(start=pd.Timestamp("2000-01-01"), length=48, freq="h") + + def helper_routine(idx, attr, vals_exp, **kwargs): + ts = datetime_attribute_timeseries(idx, attribute=attr, **kwargs) + vals_exp = np.array(vals_exp, dtype=ts.dtype) + if len(vals_exp.shape) == 1: + vals_act = ts.values()[:, 0] + else: + vals_act = ts.values() + np.testing.assert_array_almost_equal(vals_act, vals_exp) + + # no pd.DatetimeIndex + with pytest.raises(ValueError) as err: + helper_routine( + pd.RangeIndex(start=0, stop=len(idx)), "h", vals_exp=np.arange(len(idx)) + ) + assert str(err.value).startswith( + "`time_index` must be a pandas `DatetimeIndex`" + ) + + # invalid attribute + with pytest.raises(ValueError) as err: + helper_routine(idx, "h", vals_exp=np.arange(len(idx))) + assert str(err.value).startswith( + "attribute `h` needs to be an attribute of pd.DatetimeIndex." + ) + + # no time zone aware index + with pytest.raises(ValueError) as err: + helper_routine(idx.tz_localize("UTC"), "h", vals_exp=np.arange(len(idx))) + assert "`time_index` must be time zone naive." == str(err.value) + + # ===> datetime attribute + # hour + vals = [i for i in range(24)] * 2 + helper_routine(idx, "hour", vals_exp=vals) + + # hour from TimeSeries + helper_routine( + TimeSeries.from_times_and_values(times=idx, values=np.arange(len(idx))), + "hour", + vals_exp=vals, + ) + + # tz=CET is +1 hour to UTC + vals = vals[1:] + [0] + helper_routine(idx, "hour", vals_exp=vals, tz="CET") + + # day + vals = [1] * 24 + [2] * 24 + helper_routine(idx, "day", vals_exp=vals) + + # dayofweek + vals = [5] * 24 + [6] * 24 + helper_routine(idx, "dayofweek", vals_exp=vals) + + # month + vals = [1] * 48 + helper_routine(idx, "month", vals_exp=vals) + + # ===> one hot encoded + # month + vals = [1] + [0] * 11 + vals = [vals for _ in range(48)] + helper_routine(idx, "month", vals_exp=vals, one_hot=True) + + # tz=CET, month + vals = [1] + [0] * 11 + vals = [vals for _ in range(48)] + helper_routine(idx, "month", vals_exp=vals, tz="CET", one_hot=True) + + # ===> sine/cosine cyclic encoding + # hour (period = 24 hours in one day) + period = 24 + freq = 2 * np.pi / period + vals_dta = [i for i in range(24)] * 2 + vals = np.array(vals_dta) + sin_vals = np.sin(freq * vals)[:, None] + cos_vals = np.cos(freq * vals)[:, None] + vals = np.concatenate([sin_vals, cos_vals], axis=1) + helper_routine(idx, "hour", vals_exp=vals, cyclic=True) + + # tz=CET, hour + vals = np.array(vals_dta[1:] + [0]) + sin_vals = np.sin(freq * vals)[:, None] + cos_vals = np.cos(freq * vals)[:, None] + vals = np.concatenate([sin_vals, cos_vals], axis=1) + helper_routine(idx, "hour", vals_exp=vals, tz="CET", cyclic=True) diff --git a/darts/timeseries.py b/darts/timeseries.py index b75ccefb7e..e841ad60fe 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -2991,7 +2991,11 @@ def univariate_component(self, index: Union[str, int]) -> Self: return self[index if isinstance(index, str) else self.components[index]] def add_datetime_attribute( - self, attribute, one_hot: bool = False, cyclic: bool = False + self, + attribute, + one_hot: bool = False, + cyclic: bool = False, + tz: Optional[str] = None, ) -> "TimeSeries": """ Build a new series with one (or more) additional component(s) that contain an attribute @@ -3012,6 +3016,8 @@ def add_datetime_attribute( Boolean value indicating whether to add the specified attribute as a cyclic encoding. Alternative to one_hot encoding, enable only one of the two. (adds 2 columns, corresponding to sin and cos transformation). + tz + Optionally, a time zone to convert the time index to before computing the attributes. Returns ------- @@ -3023,12 +3029,20 @@ def add_datetime_attribute( return self.stack( tg.datetime_attribute_timeseries( - self.time_index, attribute, one_hot, cyclic + self.time_index, + attribute=attribute, + one_hot=one_hot, + cyclic=cyclic, + tz=tz, ) ) def add_holidays( - self, country_code: str, prov: str = None, state: str = None + self, + country_code: str, + prov: str = None, + state: str = None, + tz: Optional[str] = None, ) -> "TimeSeries": """ Adds a binary univariate component to the current series that equals 1 at every index that @@ -3048,6 +3062,8 @@ def add_holidays( The province state The state + tz + Optionally, a time zone to convert the time index to before computing the attributes. Returns ------- @@ -3058,7 +3074,13 @@ def add_holidays( from .utils import timeseries_generation as tg return self.stack( - tg.holidays_timeseries(self.time_index, country_code, prov, state) + tg.holidays_timeseries( + self.time_index, + country_code=country_code, + prov=prov, + state=state, + tz=tz, + ) ) def resample(self, freq: str, method: str = "pad", **kwargs) -> Self: diff --git a/darts/utils/timeseries_generation.py b/darts/utils/timeseries_generation.py index d07aaf164a..da1d2a524c 100644 --- a/darts/utils/timeseries_generation.py +++ b/darts/utils/timeseries_generation.py @@ -4,7 +4,7 @@ """ import math -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import holidays import numpy as np @@ -528,7 +528,7 @@ def _extend_time_index_until( def holidays_timeseries( - time_index: pd.DatetimeIndex, + time_index: Union[TimeSeries, pd.DatetimeIndex], country_code: str, prov: str = None, state: str = None, @@ -536,6 +536,7 @@ def holidays_timeseries( until: Optional[Union[int, str, pd.Timestamp]] = None, add_length: int = 0, dtype: np.dtype = np.float64, + tz: Optional[str] = None, ) -> TimeSeries: """ Creates a binary univariate TimeSeries with index `time_index` that equals 1 at every index that lies within @@ -546,13 +547,13 @@ def holidays_timeseries( Parameters ---------- time_index - The time index over which to generate the holidays + Either a `pd.DatetimeIndex` or a `TimeSeries` for which to generate the holidays. country_code - The country ISO code + The country ISO code. prov - The province + The province. state - The state + The state. until Extend the time_index up until timestamp for datetime indexed series and int for range indexed series, should match or exceed forecasting window. @@ -560,17 +561,24 @@ def holidays_timeseries( Extend the time_index by add_length, should match or exceed forecasting window. Set only one of until and add_length. column_name - Optionally, the name of the value column for the returned TimeSeries + Optionally, the name of the value column for the returned TimeSeries. dtype - The desired NumPy dtype (np.float32 or np.float64) for the resulting series + The desired NumPy dtype (np.float32 or np.float64) for the resulting series. + tz + Optionally, a time zone to convert the time index to before generating the holidays. Returns ------- TimeSeries A new binary holiday TimeSeries instance. """ + time_index_ts, time_index = _process_time_index( + time_index=time_index, + tz=tz, + until=until, + add_length=add_length, + ) - time_index = _extend_time_index_until(time_index, until, add_length) scope = range(time_index[0].year, (time_index[-1] + pd.Timedelta(days=1)).year) country_holidays = holidays.country_holidays( country_code, prov=prov, state=state, years=scope @@ -578,7 +586,7 @@ def holidays_timeseries( index_series = pd.Series(time_index, index=time_index) values = index_series.apply(lambda x: x in country_holidays).astype(dtype) return TimeSeries.from_times_and_values( - time_index, values, columns=pd.Index([column_name]) + time_index_ts, values, columns=pd.Index([column_name]) ) @@ -591,6 +599,7 @@ def datetime_attribute_timeseries( add_length: int = 0, dtype=np.float64, with_columns: Optional[Union[List[str], str]] = None, + tz: Optional[str] = None, ) -> TimeSeries: """ Returns a new TimeSeries with index `time_index` and one or more dimensions containing @@ -628,6 +637,8 @@ def datetime_attribute_timeseries( cosine component name. * If `one_hot` is ``True``, must be a list of strings of the same length as the generated one hot encoded features. + tz + Optionally, a time zone to convert the time index to before computing the attributes. Returns ------- @@ -635,10 +646,12 @@ def datetime_attribute_timeseries( New datetime attribute TimeSeries instance. """ - if isinstance(time_index, TimeSeries): - time_index = time_index.time_index - - time_index = _extend_time_index_until(time_index, until, add_length) + time_index_ts, time_index = _process_time_index( + time_index=time_index, + tz=tz, + until=until, + add_length=add_length, + ) raise_if_not( hasattr(pd.DatetimeIndex, attribute) @@ -743,7 +756,7 @@ def datetime_attribute_timeseries( ) values_df = pd.DataFrame({with_columns: values}) - values_df.index = time_index + values_df.index = time_index_ts return TimeSeries.from_dataframe(values_df).astype(dtype) @@ -818,3 +831,42 @@ def _generate_new_dates( return generate_index( start=start, freq=input_series.freq, length=n, name=input_series.time_dim ) + + +def _process_time_index( + time_index: Union[TimeSeries, pd.DatetimeIndex], + tz: Optional[str] = None, + until: Optional[Union[int, str, pd.Timestamp]] = None, + add_length: int = 0, +) -> Tuple[pd.DatetimeIndex, pd.DatetimeIndex]: + """ + Extracts the time index, and optionally adds some time steps after the end of the index, and/or converts the + index to another time zone. + + Returns a tuple of pd.DatetimeIndex with the first being the naive time index for generating a new TimeSeries, + and the second being the one used for generating datetime attributes and holidays in a potentially different + time zone. + """ + if isinstance(time_index, TimeSeries): + time_index = time_index.time_index + + if not isinstance(time_index, pd.DatetimeIndex): + raise_log( + ValueError( + "`time_index` must be a pandas `DatetimeIndex` or a `TimeSeries` indexed with a `DatetimeIndex`." + ), + logger=logger, + ) + if time_index.tz is not None: + raise_log( + ValueError("`time_index` must be time zone naive."), + logger=logger, + ) + time_index = _extend_time_index_until(time_index, until, add_length) + + # convert to another time zone + if tz is not None: + time_index_ = time_index.tz_localize("UTC").tz_convert(tz) + else: + time_index_ = time_index + return time_index, time_index_ From af5b141893d8254ca427fcb770f3056a9d9830ae Mon Sep 17 00:00:00 2001 From: Freddie Hsin-Fu Huang Date: Mon, 6 Nov 2023 23:39:05 +0800 Subject: [PATCH 2/4] fix: torch_forecasting_model load_weights with float16 and float32 (#2046) --- CHANGELOG.md | 1 + .../forecasting/torch_forecasting_model.py | 9 +++++++- .../test_torch_forecasting_model.py | 23 +++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4a68470899..4d2327303c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Fixed** - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou). - Fixed a bug when using encoders with `RegressionModel` and series with a non-evenly spaced frequency (e.g. Month Begin). This raised an error during lagged data creation when trying to divide a pd.Timedelta by the ambiguous frequency. [#2034](https://github.com/unit8co/darts/pull/2034) by [Antoine Madrona](https://github.com/madtoinou). +- Fixed a bug when loading a `TorchForecastingModel` that was trained with a precision other than `float64`. [#2046](https://github.com/unit8co/darts/pull/2046) by [Freddie Hsin-Fu Huang](https://github.com/Hsinfu). ### For developers of the library: diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 8312808689..28cc5c8d0f 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -107,6 +107,12 @@ RUNS_FOLDER = "runs" INIT_MODEL_NAME = "_model.pth.tar" +TORCH_NP_DTYPES = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, +} + # pickling a TorchForecastingModel will not save below attributes: the keys specify the # attributes to be ignored, and the values are the default values getting assigned upon loading TFM_ATTRS_NO_PICKLE = {"model": None, "trainer": None} @@ -1872,8 +1878,9 @@ def load_weights_from_checkpoint( ) # pl_forecasting module saves the train_sample shape, must recreate one + np_dtype = TORCH_NP_DTYPES[ckpt["model_dtype"]] mock_train_sample = [ - np.zeros(sample_shape) if sample_shape else None + np.zeros(sample_shape, dtype=np_dtype) if sample_shape else None for sample_shape in ckpt["train_sample_shape"] ] self.train_sample = tuple(mock_train_sample) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 400899b76a..6a150bbfa3 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1045,6 +1045,29 @@ def test_load_weights(self, tmpdir_fn): f"respectively {retrained_mape} and {original_mape}" ) + def test_load_weights_with_float32_dtype(self, tmpdir_fn): + ts_float32 = self.series.astype("float32") + model_name = "test_model" + ckpt_path = os.path.join(tmpdir_fn, f"{model_name}.pt") + # barebone model + model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + n_epochs=1, + ) + model.fit(ts_float32) + model.save(ckpt_path) + assert model.model._dtype == torch.float32 # type: ignore + + # identical model + loading_model = DLinearModel( + input_chunk_length=4, + output_chunk_length=1, + ) + loading_model.load_weights(ckpt_path) + loading_model.fit(ts_float32) + assert loading_model.model._dtype == torch.float32 # type: ignore + def test_multi_steps_pipeline(self, tmpdir_fn): ts_training, ts_val = self.series.split_before(75) pretrain_model_name = "pre-train" From a5a4306f88b2f548e55665551ba9a4c101dff3fd Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Wed, 8 Nov 2023 08:39:35 +0100 Subject: [PATCH 3/4] fix num_samles not being passed to optimized hist fc rountine of torch models (#2060) --- .../historical_forecasts/optimized_historical_forecasts_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py index 085c5efa64..f41cbbfdfb 100644 --- a/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py +++ b/darts/utils/historical_forecasts/optimized_historical_forecasts_torch.py @@ -115,6 +115,7 @@ def _optimized_historical_forecasts( dataset, trainer=None, verbose=verbose, + num_samples=num_samples, predict_likelihood_parameters=predict_likelihood_parameters, ) From da049e56d95c4d73de334d9ff41457c8962fef36 Mon Sep 17 00:00:00 2001 From: madtoinou <32447896+madtoinou@users.noreply.github.com> Date: Wed, 8 Nov 2023 08:51:25 +0100 Subject: [PATCH 4/4] Fix/exp smooth constructor args (#2059) * feat: adding support for constructor kwargs * feat: adding tests * fix: udpated representation test for ExponentialSmoothing model * update changelog.md --------- Co-authored-by: dennisbader --- CHANGELOG.md | 1 + .../forecasting/exponential_smoothing.py | 12 ++- .../forecasting/test_exponential_smoothing.py | 79 +++++++++++++++---- .../test_local_forecasting_models.py | 2 +- 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d2327303c..0fa67fd3de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader). - Other improvements: - Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, time series generation utils functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`. [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader). + - Added optional keyword arguments dict `kwargs` to `ExponentialSmoothing` that will be passed to the constructor of the underlying `statsmodels.tsa.holtwinters.ExponentialSmoothing` model. [#2059](https://github.com/unit8co/darts/pull/2059) by [Antoine Madrona](https://github.com/madtoinou). **Fixed** - Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou). diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index dda34b6992..a847b155d5 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -3,7 +3,7 @@ --------------------- """ -from typing import Optional +from typing import Any, Dict, Optional import numpy as np import statsmodels.tsa.holtwinters as hw @@ -24,7 +24,8 @@ def __init__( seasonal: Optional[SeasonalityMode] = SeasonalityMode.ADDITIVE, seasonal_periods: Optional[int] = None, random_state: int = 0, - **fit_kwargs, + kwargs: Optional[Dict[str, Any]] = None, + **fit_kwargs ): """Exponential Smoothing @@ -61,6 +62,11 @@ def __init__( seasonal_periods The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily data with a weekly cycle. If not set, inferred from frequency of the series. + kwargs + Some optional keyword arguments that will be used to call + :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing()`. + See `the documentation + `_. fit_kwargs Some optional keyword arguments that will be used to call :func:`statsmodels.tsa.holtwinters.ExponentialSmoothing.fit()`. @@ -91,6 +97,7 @@ def __init__( self.seasonal = seasonal self.infer_seasonal_periods = seasonal_periods is None self.seasonal_periods = seasonal_periods + self.constructor_kwargs = dict() if kwargs is None else kwargs self.fit_kwargs = fit_kwargs self.model = None np.random.seed(random_state) @@ -120,6 +127,7 @@ def fit(self, series: TimeSeries): seasonal_periods=seasonal_periods_param, freq=series.freq if series.has_datetime_index else None, dates=series.time_index if series.has_datetime_index else None, + **self.constructor_kwargs ) hw_results = hw_model.fit(**self.fit_kwargs) self.model = hw_results diff --git a/darts/tests/models/forecasting/test_exponential_smoothing.py b/darts/tests/models/forecasting/test_exponential_smoothing.py index 173a2ba508..63b494ae44 100644 --- a/darts/tests/models/forecasting/test_exponential_smoothing.py +++ b/darts/tests/models/forecasting/test_exponential_smoothing.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from darts import TimeSeries from darts.models import ExponentialSmoothing @@ -6,36 +7,86 @@ class TestExponentialSmoothing: - def helper_test_seasonality_inference(self, freq_string, expected_seasonal_periods): - series = tg.sine_timeseries(length=200, freq=freq_string) - model = ExponentialSmoothing() - model.fit(series) - assert model.seasonal_periods == expected_seasonal_periods + series = tg.sine_timeseries(length=100, freq="H") - def test_seasonality_inference(self): - - # test `seasonal_periods` inference for datetime indices - freq_str_seasonality_periods_tuples = [ + @pytest.mark.parametrize( + "freq_string,expected_seasonal_periods", + [ ("D", 7), ("H", 24), ("M", 12), ("W", 52), ("Q", 4), ("B", 5), - ] - for tuple in freq_str_seasonality_periods_tuples: - self.helper_test_seasonality_inference(*tuple) + ], + ) + def test_seasonality_inference( + self, freq_string: str, expected_seasonal_periods: int + ): + series = tg.sine_timeseries(length=200, freq=freq_string) + model = ExponentialSmoothing() + model.fit(series) + assert model.seasonal_periods == expected_seasonal_periods - # test default selection for integer index + def test_default_parameters(self): + """Test default selection for integer index""" series = TimeSeries.from_values(np.arange(1, 30, 1)) model = ExponentialSmoothing() model.fit(series) assert model.seasonal_periods == 12 - # test whether a model that inferred a seasonality period before will do it again for a new series + def test_multiple_fit(self): + """Test whether a model that inferred a seasonality period before will do it again for a new series""" series1 = tg.sine_timeseries(length=100, freq="M") series2 = tg.sine_timeseries(length=100, freq="D") model = ExponentialSmoothing() model.fit(series1) model.fit(series2) assert model.seasonal_periods == 7 + + def test_constructor_kwargs(self): + """Using kwargs to pass additional parameters to the constructor""" + constructor_kwargs = { + "initialization_method": "known", + "initial_level": 0.5, + "initial_trend": 0.2, + "initial_seasonal": np.arange(1, 25), + } + model = ExponentialSmoothing(kwargs=constructor_kwargs) + model.fit(self.series) + # must be checked separately, name is not consistent + np.testing.assert_array_almost_equal( + model.model.model.params["initial_seasons"], + constructor_kwargs["initial_seasonal"], + ) + for param_name in ["initial_level", "initial_trend"]: + assert ( + model.model.model.params[param_name] == constructor_kwargs[param_name] + ) + + def test_fit_kwargs(self): + """Using kwargs to pass additional parameters to the fit()""" + # using default optimization method + model = ExponentialSmoothing() + model.fit(self.series) + assert model.fit_kwargs == {} + pred = model.predict(n=2) + + model_bis = ExponentialSmoothing() + model_bis.fit(self.series) + assert model_bis.fit_kwargs == {} + pred_bis = model_bis.predict(n=2) + + # two methods with the same parameters should yield the same forecasts + assert pred.time_index.equals(pred_bis.time_index) + np.testing.assert_array_almost_equal(pred.values(), pred_bis.values()) + + # change optimization method + model_ls = ExponentialSmoothing(method="least_squares") + model_ls.fit(self.series) + assert model_ls.fit_kwargs == {"method": "least_squares"} + pred_ls = model_ls.predict(n=2) + + # forecasts should be slightly different + assert pred.time_index.equals(pred_ls.time_index) + assert all(np.not_equal(pred.values(), pred_ls.values())) diff --git a/darts/tests/models/forecasting/test_local_forecasting_models.py b/darts/tests/models/forecasting/test_local_forecasting_models.py index 557883b4c0..f3ac21d40d 100644 --- a/darts/tests/models/forecasting/test_local_forecasting_models.py +++ b/darts/tests/models/forecasting/test_local_forecasting_models.py @@ -651,7 +651,7 @@ def test_model_str_call(self, config): ( ExponentialSmoothing(), "ExponentialSmoothing(trend=ModelMode.ADDITIVE, damped=False, seasonal=SeasonalityMode.ADDITIVE, " - + "seasonal_periods=None, random_state=0)", + + "seasonal_periods=None, random_state=0, kwargs=None)", ), # no params changed ( ARIMA(1, 1, 1),