From c498405cdab004e7004ecc586f623ed343cc1d43 Mon Sep 17 00:00:00 2001 From: Dennis Bader <dennis.bader@gmx.ch> Date: Mon, 24 Jun 2024 16:26:13 +0200 Subject: [PATCH] Fix/regr model hist fc static covs no target lags (#2426) * fix bug where regression model historical forecasts with static covariates and no target lags failed * fix issues with xgboost v2.1.0 * update changelog * add unit tests for regression model hist fc --- CHANGELOG.md | 2 + darts/models/forecasting/catboost_model.py | 2 +- darts/models/forecasting/lgbm.py | 3 +- .../forecasting/linear_regression_model.py | 2 +- darts/models/forecasting/regression_model.py | 38 ++++++++---- darts/models/forecasting/xgboost.py | 2 +- .../forecasting/test_regression_models.py | 60 ++++++++++++++++++- darts/utils/data/tabularization.py | 22 ++++++- ...timized_historical_forecasts_regression.py | 3 + 9 files changed, 114 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d528e013e..637594cdfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** **Fixed** +- Fixed a bug when using `historical_forecasts()` with a pre-trained `RegressionModel` that has no target lags `lags=None` but uses static covariates. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader). +- Fixed a bug with `xgboost>=2.1.0`, where multi output regression was not properly handled. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader). **Dependencies** diff --git a/darts/models/forecasting/catboost_model.py b/darts/models/forecasting/catboost_model.py index 1ed9580a97..883ac94f6b 100644 --- a/darts/models/forecasting/catboost_model.py +++ b/darts/models/forecasting/catboost_model.py @@ -165,7 +165,7 @@ def encode_year(idx): self._median_idx = None self._model_container = None self._rng = None - self.likelihood = likelihood + self._likelihood = likelihood self.quantiles = None self._output_chunk_length = output_chunk_length diff --git a/darts/models/forecasting/lgbm.py b/darts/models/forecasting/lgbm.py index 512cf4b264..10d4de64d6 100644 --- a/darts/models/forecasting/lgbm.py +++ b/darts/models/forecasting/lgbm.py @@ -188,7 +188,7 @@ def encode_year(idx): self._median_idx = None self._model_container = None self.quantiles = None - self.likelihood = likelihood + self._likelihood = likelihood self._rng = None # parse likelihood @@ -294,7 +294,6 @@ def fit( val_sample_weight=val_sample_weight, **kwargs, ) - self._model_container[quantile] = self.model return self diff --git a/darts/models/forecasting/linear_regression_model.py b/darts/models/forecasting/linear_regression_model.py index 36dabe0d5e..e5d856299a 100644 --- a/darts/models/forecasting/linear_regression_model.py +++ b/darts/models/forecasting/linear_regression_model.py @@ -174,7 +174,7 @@ def encode_year(idx): self._median_idx = None self._model_container = None self.quantiles = None - self.likelihood = likelihood + self._likelihood = likelihood self._rng = None # parse likelihood diff --git a/darts/models/forecasting/regression_model.py b/darts/models/forecasting/regression_model.py index 91bd9ba5a8..79227114fe 100644 --- a/darts/models/forecasting/regression_model.py +++ b/darts/models/forecasting/regression_model.py @@ -831,30 +831,42 @@ def fit( } # if multi-output regression + use_mor = False if not series[0].is_univariate or ( self.output_chunk_length > 1 and self.multi_models and not isinstance(self.model, MultiOutputRegressor) ): - val_set_name, val_weight_name = self.val_set_params - mor_kwargs = { - "eval_set_name": val_set_name, - "eval_weight_name": val_weight_name, - "n_jobs": n_jobs_multioutput_wrapper, - } if sample_weight is not None: # we have 2D sample (and time) weights, only supported in Darts - self.model = MultiOutputRegressor(self.model, **mor_kwargs) + use_mor = True elif not ( callable(getattr(self.model, "_get_tags", None)) and isinstance(self.model._get_tags(), dict) and self.model._get_tags().get("multioutput") ): # model does not support multi-output regression natively - self.model = MultiOutputRegressor(self.model, **mor_kwargs) - elif self.model.__class__.__name__ == "CatBoostRegressor": - if self.model.get_params()["loss_function"] == "RMSEWithUncertainty": - self.model = MultiOutputRegressor(self.model, **mor_kwargs) + use_mor = True + elif ( + self.model.__class__.__name__ == "CatBoostRegressor" + and self.model.get_params()["loss_function"] == "RMSEWithUncertainty" + ): + use_mor = True + elif ( + self.model.__class__.__name__ == "XGBRegressor" + and self.likelihood is not None + ): + # since xgboost==2.1.0, likelihoods do not support native multi output regression + use_mor = True + + if use_mor: + val_set_name, val_weight_name = self.val_set_params + mor_kwargs = { + "eval_set_name": val_set_name, + "eval_weight_name": val_weight_name, + "n_jobs": n_jobs_multioutput_wrapper, + } + self.model = MultiOutputRegressor(self.model, **mor_kwargs) # warn if n_jobs_multioutput_wrapper was provided but not used if ( @@ -1226,6 +1238,10 @@ def lagged_label_names(self) -> Optional[List[str]]: def __str__(self): return self.model.__str__() + @property + def likelihood(self) -> Optional[str]: + return getattr(self, "_likelihood", None) + @property def supports_past_covariates(self) -> bool: return len(self.lags.get("past", [])) > 0 diff --git a/darts/models/forecasting/xgboost.py b/darts/models/forecasting/xgboost.py index 79bfdcd27d..bd7bd25c52 100644 --- a/darts/models/forecasting/xgboost.py +++ b/darts/models/forecasting/xgboost.py @@ -193,7 +193,7 @@ def encode_year(idx): self._median_idx = None self._model_container = None self.quantiles = None - self.likelihood = likelihood + self._likelihood = likelihood self._rng = None # parse likelihood diff --git a/darts/tests/models/forecasting/test_regression_models.py b/darts/tests/models/forecasting/test_regression_models.py index 176dedf80a..4266af7b1a 100644 --- a/darts/tests/models/forecasting/test_regression_models.py +++ b/darts/tests/models/forecasting/test_regression_models.py @@ -1315,7 +1315,7 @@ def test_multioutput_wrapper(self, config): horizon=0, target_dim=1 ) - model_configs = [(XGBModel, dict({"tree_method": "exact"}, **xgb_test_params))] + model_configs = [(XGBModel, dict({"likelihood": "poisson"}, **xgb_test_params))] if lgbm_available: model_configs += [(LightGBMModel, lgbm_test_params)] if cb_available: @@ -1341,7 +1341,15 @@ def test_get_multioutput_estimator_multi_models(self): """Craft training data so that estimator_[i].predict(X) == i + 1""" def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int): - m = XGBModel(lags=3, output_chunk_length=ocl, multi_models=True) + # since xgboost==2.1.0, the regular deterministic models have native multi output regression + # -> we use a quantile likelihood to activate Darts' MultiOutputRegressor + m = XGBModel( + lags=3, + output_chunk_length=ocl, + multi_models=True, + likelihood="quantile", + quantiles=[0.5], + ) m.fit(ts) assert len(m.model.estimators_) == ocl * ts.width @@ -1401,7 +1409,15 @@ def test_get_multioutput_estimator_single_model(self): # estimators_[0] labels : [1] # estimators_[1] labels : [2] - m = XGBModel(lags=3, output_chunk_length=ocl, multi_models=False) + # since xgboost==2.1.0, the regular deterministic models have native multi output regression + # -> we use a quantile likelihood to activate Darts' MultiOutputRegressor + m = XGBModel( + lags=3, + output_chunk_length=ocl, + multi_models=False, + likelihood="quantile", + quantiles=[0.5], + ) m.fit(ts) # one estimator is reused for all the horizon of a given component @@ -2593,6 +2609,44 @@ def test_output_shift(self, config): hist_fc_opt.values(copy=False), pred_last_hist_fc[-1].values(copy=False) ) + @pytest.mark.parametrize("lpo", [True, False]) + def test_historical_forecasts_no_target_lags_with_static_covs(self, lpo): + """Tests that historical forecasts work without target lags but with static covariates. + For last_points_only `True` and `False`.""" + ocl = 7 + series = tg.linear_timeseries( + length=28, start=pd.Timestamp("2000-01-01"), freq="d" + ).with_static_covariates(pd.Series([1.0])) + + model = LinearRegressionModel( + lags=None, + lags_future_covariates=(3, 0), + output_chunk_length=ocl, + use_static_covariates=True, + ) + model.fit(series, future_covariates=series) + + preds1 = model.historical_forecasts( + series, + future_covariates=series, + retrain=False, + enable_optimization=True, + last_points_only=lpo, + ) + preds2 = model.historical_forecasts( + series, + future_covariates=series, + retrain=False, + enable_optimization=False, + last_points_only=lpo, + ) + if lpo: + preds1 = [preds1] + preds2 = [preds2] + + for p1, p2 in zip(preds1, preds2): + np.testing.assert_array_almost_equal(p1.values(), p2.values()) + @pytest.mark.parametrize( "config", itertools.product( diff --git a/darts/utils/data/tabularization.py b/darts/utils/data/tabularization.py index f3ba540cc9..157b0f90c5 100644 --- a/darts/utils/data/tabularization.py +++ b/darts/utils/data/tabularization.py @@ -43,6 +43,7 @@ def create_lagged_data( is_training: bool = True, concatenate: bool = True, sample_weight: Optional[Union[str, TimeSeries, Sequence[TimeSeries]]] = None, + show_warnings: bool = True, ) -> Tuple[ ArrayOrArraySequence, Union[None, ArrayOrArraySequence], @@ -224,6 +225,8 @@ def create_lagged_data( `"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are computed globally based on the length of the longest series in `series`. Then for each series, the weights are extracted from the end of the global weights. This gives a common time weighting across all series. + show_warnings + Whether to show warnings. Returns ------- @@ -359,6 +362,7 @@ def create_lagged_data( multi_models=multi_models, check_inputs=check_inputs, is_training=is_training, + show_warnings=show_warnings, ) else: X_i, y_i, times_i, weights_i = _create_lagged_data_by_intersecting_times( @@ -375,6 +379,7 @@ def create_lagged_data( multi_models=multi_models, check_inputs=check_inputs, is_training=is_training, + show_warnings=show_warnings, ) X_i, last_static_covariates_shape = add_static_covariates_to_lagged_data( features=X_i, @@ -576,6 +581,7 @@ def create_lagged_prediction_data( check_inputs: bool = True, use_moving_windows: bool = True, concatenate: bool = True, + show_warnings: bool = True, ) -> Tuple[ArrayOrArraySequence, Sequence[pd.Index]]: """ Creates the features array `X` to produce a series of prediction from an already-trained regression model; the @@ -640,6 +646,8 @@ def create_lagged_prediction_data( `Sequence[TimeSeries]` are provided, then `X` will be an array created by concatenating all feature arrays formed by each `TimeSeries` along the `0`th axis. Note that `times` is still returned as `Sequence[pd.Index]`, even when `concatenate = True`. + show_warnings + Whether to show warnings. Returns ------- @@ -680,6 +688,7 @@ def create_lagged_prediction_data( use_moving_windows=use_moving_windows, is_training=False, concatenate=concatenate, + show_warnings=show_warnings, ) return X, times @@ -963,6 +972,7 @@ def _create_lagged_data_by_moving_window( multi_models: bool, check_inputs: bool, is_training: bool, + show_warnings: bool = True, ) -> Tuple[np.ndarray, Optional[np.ndarray], pd.Index, Optional[np.ndarray]]: """ Helper function called by `create_lagged_data` that computes `X`, `y`, and `times` by @@ -991,6 +1001,7 @@ def _create_lagged_data_by_moving_window( is_training=is_training, return_min_and_max_lags=True, check_inputs=check_inputs, + show_warnings=show_warnings, ) if check_inputs: series_and_lags_not_specified = [max_lag is None for max_lag in max_lags] @@ -1197,6 +1208,7 @@ def _create_lagged_data_by_intersecting_times( multi_models: bool, check_inputs: bool, is_training: bool, + show_warnings: bool = True, ) -> Tuple[ np.ndarray, Optional[np.ndarray], @@ -1224,6 +1236,7 @@ def _create_lagged_data_by_intersecting_times( is_training=is_training, return_min_and_max_lags=True, check_inputs=check_inputs, + show_warnings=show_warnings, ) if check_inputs: series_and_lags_not_specified = [min_lag is None for min_lag in min_lags] @@ -1460,6 +1473,7 @@ def _get_feature_times( is_training: bool = True, return_min_and_max_lags: bool = False, check_inputs: bool = True, + show_warnings: bool = True, ) -> Union[FeatureTimes, Tuple[FeatureTimes, MinLags, MaxLags]]: """ Returns a tuple containing the times in `target_series`, the times in `past_covariates`, and the times in @@ -1571,6 +1585,8 @@ def _get_feature_times( return_min_and_max_lags Optionally, specifies whether the largest magnitude lag value for each series should also be returned along with the 'eligible' feature times + show_warnings + Whether to show warnings. Note: if the lags are provided as a dictionary for the target series or any of the covariates series, the component-specific lags are grouped into a single list to compute the corresponding feature time. @@ -1673,7 +1689,11 @@ def _get_feature_times( # `target_series`/`past_covariates` in `Notes`: if max_lag_i > 0: times_i = times_i[max_lag_i:] - elif (not is_label_series) and (series_specified ^ lags_specified): + elif ( + show_warnings + and (not is_label_series) + and (series_specified ^ lags_specified) + ): # Warn user that series/lags input will be ignored: times_i = max_lag_i = None lags_name = "lags" if name_i == "target_series" else f"lags_{name_i}" diff --git a/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py b/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py index 86fb04e28a..c480d69c57 100644 --- a/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py +++ b/darts/utils/historical_forecasts/optimized_historical_forecasts_regression.py @@ -105,6 +105,7 @@ def _optimized_historical_forecasts_last_points_only( target_series=( None if model._get_lags("target") is None + and not model.uses_static_covariates else series_[hist_fct_tgt_start:hist_fct_tgt_end] ), past_covariates=( @@ -260,6 +261,7 @@ def _optimized_historical_forecasts_all_points( target_series=( None if model._get_lags("target") is None + and not model.uses_static_covariates else series_[hist_fct_tgt_start:hist_fct_tgt_end] ), past_covariates=( @@ -281,6 +283,7 @@ def _optimized_historical_forecasts_all_points( check_inputs=True, use_moving_windows=True, concatenate=False, + show_warnings=False, ) # stride must be applied post-hoc to avoid missing values