Fix/regr model hist fc static covs no target lags (#2426)
* fix bug where regression model historical forecasts with static covariates and no target lags failed

* fix issues with xgboost v2.1.0

* update changelog

* add unit tests for regression model hist fc
dennisbader authored Jun 24, 2024
1 parent 3115bb6 commit c498405
Showing 9 changed files with 114 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@
**Improved**

**Fixed**
+ - Fixed a bug when using `historical_forecasts()` with a pre-trained `RegressionModel` that has no target lags (`lags=None`) but uses static covariates. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader).
+ - Fixed a bug with `xgboost>=2.1.0`, where multi-output regression was not properly handled. [#2426](https://github.com/unit8co/darts/pull/2426) by [Dennis Bader](https://github.com/dennisbader).

**Dependencies**

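For readers of the first changelog entry, here is a minimal sketch of the configuration it refers to, adapted from the unit test added later in this commit (the series content is illustrative):

```python
import pandas as pd

import darts.utils.timeseries_generation as tg
from darts.models import LinearRegressionModel

# A model with no target lags (lags=None) whose only target-derived
# features are static covariates.
series = tg.linear_timeseries(
    length=28, start=pd.Timestamp("2000-01-01"), freq="d"
).with_static_covariates(pd.Series([1.0]))

model = LinearRegressionModel(
    lags=None,
    lags_future_covariates=(3, 0),
    output_chunk_length=7,
    use_static_covariates=True,
)
model.fit(series, future_covariates=series)

# Before this fix, calling historical_forecasts() on the pre-trained model
# (retrain=False) failed for exactly this configuration.
preds = model.historical_forecasts(series, future_covariates=series, retrain=False)
```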
2 changes: 1 addition & 1 deletion darts/models/forecasting/catboost_model.py
@@ -165,7 +165,7 @@ def encode_year(idx):
self._median_idx = None
self._model_container = None
self._rng = None
- self.likelihood = likelihood
+ self._likelihood = likelihood
self.quantiles = None

self._output_chunk_length = output_chunk_length
3 changes: 1 addition & 2 deletions darts/models/forecasting/lgbm.py
@@ -188,7 +188,7 @@ def encode_year(idx):
self._median_idx = None
self._model_container = None
self.quantiles = None
- self.likelihood = likelihood
+ self._likelihood = likelihood
self._rng = None

# parse likelihood
@@ -294,7 +294,6 @@ def fit(
val_sample_weight=val_sample_weight,
**kwargs,
)
-
self._model_container[quantile] = self.model
return self

2 changes: 1 addition & 1 deletion darts/models/forecasting/linear_regression_model.py
@@ -174,7 +174,7 @@ def encode_year(idx):
self._median_idx = None
self._model_container = None
self.quantiles = None
- self.likelihood = likelihood
+ self._likelihood = likelihood
self._rng = None

# parse likelihood
38 changes: 27 additions & 11 deletions darts/models/forecasting/regression_model.py
@@ -831,30 +831,42 @@ def fit(
}

# if multi-output regression
+ use_mor = False
if not series[0].is_univariate or (
self.output_chunk_length > 1
and self.multi_models
and not isinstance(self.model, MultiOutputRegressor)
):
- val_set_name, val_weight_name = self.val_set_params
- mor_kwargs = {
- "eval_set_name": val_set_name,
- "eval_weight_name": val_weight_name,
- "n_jobs": n_jobs_multioutput_wrapper,
- }
if sample_weight is not None:
# we have 2D sample (and time) weights, only supported in Darts
- self.model = MultiOutputRegressor(self.model, **mor_kwargs)
+ use_mor = True
elif not (
callable(getattr(self.model, "_get_tags", None))
and isinstance(self.model._get_tags(), dict)
and self.model._get_tags().get("multioutput")
):
# model does not support multi-output regression natively
- self.model = MultiOutputRegressor(self.model, **mor_kwargs)
- elif self.model.__class__.__name__ == "CatBoostRegressor":
- if self.model.get_params()["loss_function"] == "RMSEWithUncertainty":
- self.model = MultiOutputRegressor(self.model, **mor_kwargs)
+ use_mor = True
+ elif (
+ self.model.__class__.__name__ == "CatBoostRegressor"
+ and self.model.get_params()["loss_function"] == "RMSEWithUncertainty"
+ ):
+ use_mor = True
+ elif (
+ self.model.__class__.__name__ == "XGBRegressor"
+ and self.likelihood is not None
+ ):
+ # since xgboost==2.1.0, likelihoods do not support native multi output regression
+ use_mor = True
+
+ if use_mor:
+ val_set_name, val_weight_name = self.val_set_params
+ mor_kwargs = {
+ "eval_set_name": val_set_name,
+ "eval_weight_name": val_weight_name,
+ "n_jobs": n_jobs_multioutput_wrapper,
+ }
+ self.model = MultiOutputRegressor(self.model, **mor_kwargs)

# warn if n_jobs_multioutput_wrapper was provided but not used
if (
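The refactor above replaces three separate `MultiOutputRegressor(...)` call sites with a single `use_mor` flag that is resolved first and acted on once. A condensed sketch of the decision logic follows; the standalone function name is illustrative, and Darts wraps with its own `MultiOutputRegressor` subclass, which additionally forwards eval sets and sample weights:

```python
def needs_multioutput_wrapper(model, likelihood, sample_weight) -> bool:
    """Illustrative condensation of the use_mor conditions above."""
    if sample_weight is not None:
        # 2D sample (and time) weights are only supported through the wrapper
        return True
    if not (
        callable(getattr(model, "_get_tags", None))
        and isinstance(model._get_tags(), dict)
        and model._get_tags().get("multioutput")
    ):
        # model does not support multi-output regression natively
        return True
    if (
        model.__class__.__name__ == "CatBoostRegressor"
        and model.get_params()["loss_function"] == "RMSEWithUncertainty"
    ):
        return True
    if model.__class__.__name__ == "XGBRegressor" and likelihood is not None:
        # since xgboost==2.1.0, likelihoods lose native multi-output support
        return True
    return False
```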
@@ -1226,6 +1238,10 @@ def lagged_label_names(self) -> Optional[List[str]]:
def __str__(self):
return self.model.__str__()

+ @property
+ def likelihood(self) -> Optional[str]:
+ return getattr(self, "_likelihood", None)
+
@property
def supports_past_covariates(self) -> bool:
return len(self.lags.get("past", [])) > 0
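The `self.likelihood = ...` assignments in the subclasses become `self._likelihood = ...` because the base class now exposes `likelihood` as a read-only property; the `getattr` default keeps instances that never set the private attribute working (plausibly, models saved with older darts versions). A minimal sketch of the pattern, with an illustrative class name:

```python
from typing import Optional


class _LikelihoodMixinSketch:
    def __init__(self, likelihood: Optional[str] = None):
        # subclasses write the private attribute...
        self._likelihood = likelihood

    @property
    def likelihood(self) -> Optional[str]:
        # ...and everyone reads through the property; the getattr default
        # covers objects on which _likelihood was never set.
        return getattr(self, "_likelihood", None)


assert _LikelihoodMixinSketch("quantile").likelihood == "quantile"
assert _LikelihoodMixinSketch().likelihood is None
```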
2 changes: 1 addition & 1 deletion darts/models/forecasting/xgboost.py
@@ -193,7 +193,7 @@ def encode_year(idx):
self._median_idx = None
self._model_container = None
self.quantiles = None
- self.likelihood = likelihood
+ self._likelihood = likelihood
self._rng = None

# parse likelihood
60 changes: 57 additions & 3 deletions darts/tests/models/forecasting/test_regression_models.py
@@ -1315,7 +1315,7 @@ def test_multioutput_wrapper(self, config):
horizon=0, target_dim=1
)

model_configs = [(XGBModel, dict({"tree_method": "exact"}, **xgb_test_params))]
model_configs = [(XGBModel, dict({"likelihood": "poisson"}, **xgb_test_params))]
if lgbm_available:
model_configs += [(LightGBMModel, lgbm_test_params)]
if cb_available:
@@ -1341,7 +1341,15 @@ def test_get_multioutput_estimator_multi_models(self):
"""Craft training data so that estimator_[i].predict(X) == i + 1"""

def helper_check_overfitted_estimators(ts: TimeSeries, ocl: int):
- m = XGBModel(lags=3, output_chunk_length=ocl, multi_models=True)
+ # since xgboost==2.1.0, the regular deterministic models have native multi output regression
+ # -> we use a quantile likelihood to activate Darts' MultiOutputRegressor
+ m = XGBModel(
+ lags=3,
+ output_chunk_length=ocl,
+ multi_models=True,
+ likelihood="quantile",
+ quantiles=[0.5],
+ )
m.fit(ts)

assert len(m.model.estimators_) == ocl * ts.width
@@ -1401,7 +1409,15 @@ def test_get_multioutput_estimator_single_model(self):
# estimators_[0] labels : [1]
# estimators_[1] labels : [2]

- m = XGBModel(lags=3, output_chunk_length=ocl, multi_models=False)
+ # since xgboost==2.1.0, the regular deterministic models have native multi output regression
+ # -> we use a quantile likelihood to activate Darts' MultiOutputRegressor
+ m = XGBModel(
+ lags=3,
+ output_chunk_length=ocl,
+ multi_models=False,
+ likelihood="quantile",
+ quantiles=[0.5],
+ )
m.fit(ts)

# one estimator is reused for all the horizon of a given component
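These two tests pin down how many underlying estimators the wrapper trains: with `multi_models=True`, one per (horizon step, target component); with `multi_models=False`, one per component, reused across the horizon. A small standalone check of the first invariant (series construction is illustrative):

```python
import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.models import XGBModel

ts = TimeSeries.from_series(pd.Series(np.arange(24.0)))

# The quantile likelihood forces Darts' MultiOutputRegressor wrapper
# (xgboost>=2.1.0 would otherwise handle multi-output natively).
m = XGBModel(
    lags=3, output_chunk_length=2, multi_models=True,
    likelihood="quantile", quantiles=[0.5],
)
m.fit(ts)

# multi_models=True: one estimator per (horizon step, component)
assert len(m.model.estimators_) == 2 * ts.width
```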
@@ -2593,6 +2609,44 @@ def test_output_shift(self, config):
hist_fc_opt.values(copy=False), pred_last_hist_fc[-1].values(copy=False)
)

@pytest.mark.parametrize("lpo", [True, False])
def test_historical_forecasts_no_target_lags_with_static_covs(self, lpo):
"""Tests that historical forecasts work without target lags but with static covariates.
For last_points_only `True` and `False`."""
ocl = 7
series = tg.linear_timeseries(
length=28, start=pd.Timestamp("2000-01-01"), freq="d"
).with_static_covariates(pd.Series([1.0]))

model = LinearRegressionModel(
lags=None,
lags_future_covariates=(3, 0),
output_chunk_length=ocl,
use_static_covariates=True,
)
model.fit(series, future_covariates=series)

preds1 = model.historical_forecasts(
series,
future_covariates=series,
retrain=False,
enable_optimization=True,
last_points_only=lpo,
)
preds2 = model.historical_forecasts(
series,
future_covariates=series,
retrain=False,
enable_optimization=False,
last_points_only=lpo,
)
if lpo:
preds1 = [preds1]
preds2 = [preds2]

for p1, p2 in zip(preds1, preds2):
np.testing.assert_array_almost_equal(p1.values(), p2.values())

@pytest.mark.parametrize(
"config",
itertools.product(
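For context on the `lpo` parametrization: `historical_forecasts()` returns a single `TimeSeries` of last predicted points when `last_points_only=True`, and a list with one forecast series per historical window (each of length `forecast_horizon`) when `False` — hence the list-wrapping before the comparison. A hedged sketch, reusing `model` and `series` from the sketch after the changelog section:

```python
# last_points_only=True -> one TimeSeries (the last point of every window)
hf_last = model.historical_forecasts(
    series, future_covariates=series, retrain=False, last_points_only=True
)
# last_points_only=False -> a list of TimeSeries, one per forecast window
hf_all = model.historical_forecasts(
    series, future_covariates=series, retrain=False, last_points_only=False
)
assert not isinstance(hf_last, list)
assert isinstance(hf_all, list)
```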
22 changes: 21 additions & 1 deletion darts/utils/data/tabularization.py
@@ -43,6 +43,7 @@ def create_lagged_data(
is_training: bool = True,
concatenate: bool = True,
sample_weight: Optional[Union[str, TimeSeries, Sequence[TimeSeries]]] = None,
+ show_warnings: bool = True,
) -> Tuple[
ArrayOrArraySequence,
Union[None, ArrayOrArraySequence],
@@ -224,6 +225,8 @@
`"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are
computed globally based on the length of the longest series in `series`. Then for each series, the weights
are extracted from the end of the global weights. This gives a common time weighting across all series.
+ show_warnings
+ Whether to show warnings.
Returns
-------
@@ -359,6 +362,7 @@
multi_models=multi_models,
check_inputs=check_inputs,
is_training=is_training,
+ show_warnings=show_warnings,
)
else:
X_i, y_i, times_i, weights_i = _create_lagged_data_by_intersecting_times(
@@ -375,6 +379,7 @@
multi_models=multi_models,
check_inputs=check_inputs,
is_training=is_training,
+ show_warnings=show_warnings,
)
X_i, last_static_covariates_shape = add_static_covariates_to_lagged_data(
features=X_i,
@@ -576,6 +581,7 @@ def create_lagged_prediction_data(
check_inputs: bool = True,
use_moving_windows: bool = True,
concatenate: bool = True,
+ show_warnings: bool = True,
) -> Tuple[ArrayOrArraySequence, Sequence[pd.Index]]:
"""
Creates the features array `X` to produce a series of prediction from an already-trained regression model; the
@@ -640,6 +646,8 @@
`Sequence[TimeSeries]` are provided, then `X` will be an array created by concatenating all feature
arrays formed by each `TimeSeries` along the `0`th axis. Note that `times` is still returned as
`Sequence[pd.Index]`, even when `concatenate = True`.
+ show_warnings
+ Whether to show warnings.
Returns
-------
@@ -680,6 +688,7 @@
use_moving_windows=use_moving_windows,
is_training=False,
concatenate=concatenate,
+ show_warnings=show_warnings,
)
return X, times

@@ -963,6 +972,7 @@ def _create_lagged_data_by_moving_window(
multi_models: bool,
check_inputs: bool,
is_training: bool,
+ show_warnings: bool = True,
) -> Tuple[np.ndarray, Optional[np.ndarray], pd.Index, Optional[np.ndarray]]:
"""
Helper function called by `create_lagged_data` that computes `X`, `y`, and `times` by
@@ -991,6 +1001,7 @@
is_training=is_training,
return_min_and_max_lags=True,
check_inputs=check_inputs,
+ show_warnings=show_warnings,
)
if check_inputs:
series_and_lags_not_specified = [max_lag is None for max_lag in max_lags]
@@ -1197,6 +1208,7 @@ def _create_lagged_data_by_intersecting_times(
multi_models: bool,
check_inputs: bool,
is_training: bool,
+ show_warnings: bool = True,
) -> Tuple[
np.ndarray,
Optional[np.ndarray],
@@ -1224,6 +1236,7 @@
is_training=is_training,
return_min_and_max_lags=True,
check_inputs=check_inputs,
+ show_warnings=show_warnings,
)
if check_inputs:
series_and_lags_not_specified = [min_lag is None for min_lag in min_lags]
@@ -1460,6 +1473,7 @@ def _get_feature_times(
is_training: bool = True,
return_min_and_max_lags: bool = False,
check_inputs: bool = True,
+ show_warnings: bool = True,
) -> Union[FeatureTimes, Tuple[FeatureTimes, MinLags, MaxLags]]:
"""
Returns a tuple containing the times in `target_series`, the times in `past_covariates`, and the times in
@@ -1571,6 +1585,8 @@
return_min_and_max_lags
Optionally, specifies whether the largest magnitude lag value for each series should also be returned along with
the 'eligible' feature times
+ show_warnings
+ Whether to show warnings.
Note: if the lags are provided as a dictionary for the target series or any of the covariates series, the
component-specific lags are grouped into a single list to compute the corresponding feature time.
@@ -1673,7 +1689,11 @@
# `target_series`/`past_covariates` in `Notes`:
if max_lag_i > 0:
times_i = times_i[max_lag_i:]
- elif (not is_label_series) and (series_specified ^ lags_specified):
+ elif (
+ show_warnings
+ and (not is_label_series)
+ and (series_specified ^ lags_specified)
+ ):
# Warn user that series/lags input will be ignored:
times_i = max_lag_i = None
lags_name = "lags" if name_i == "target_series" else f"lags_{name_i}"
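The `show_warnings` flag threaded through this module lets the optimized historical-forecasts routines (next file) build prediction features for a target series that has no target lags without emitting the "series/lags input will be ignored" warning on every window. A hedged sketch of a direct call — keyword names follow the signatures shown in this diff, and the lag values are illustrative:

```python
import pandas as pd

import darts.utils.timeseries_generation as tg
from darts.utils.data.tabularization import create_lagged_prediction_data

series = tg.linear_timeseries(length=28, start=pd.Timestamp("2000-01-01"), freq="d")

# target_series is passed without target lags (e.g. only so its static
# covariates can be used); show_warnings=False suppresses the would-be warning.
X, times = create_lagged_prediction_data(
    target_series=series,
    future_covariates=series,
    lags_future_covariates=[-3, -2, -1],
    show_warnings=False,
)
```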
3 changes: 3 additions & 0 deletions
@@ -105,6 +105,7 @@ def _optimized_historical_forecasts_last_points_only(
target_series=(
None
if model._get_lags("target") is None
+ and not model.uses_static_covariates
else series_[hist_fct_tgt_start:hist_fct_tgt_end]
),
past_covariates=(
@@ -260,6 +261,7 @@ def _optimized_historical_forecasts_all_points(
target_series=(
None
if model._get_lags("target") is None
+ and not model.uses_static_covariates
else series_[hist_fct_tgt_start:hist_fct_tgt_end]
),
past_covariates=(
@@ -281,6 +283,7 @@
check_inputs=True,
use_moving_windows=True,
concatenate=False,
+ show_warnings=False,
)

# stride must be applied post-hoc to avoid missing values
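The two `uses_static_covariates` hunks above are the heart of the historical-forecasts fix: previously, the target slice was dropped whenever the model had no target lags, so tabularization never saw the static covariates and produced a feature matrix narrower than the one the model was trained on. An annotated form of the fixed condition, as it appears in both hunks:

```python
# Keep the target slice if the model has target lags OR uses the target's
# static covariates as features; drop it only when neither applies.
target_series = (
    None
    if model._get_lags("target") is None and not model.uses_static_covariates
    else series_[hist_fct_tgt_start:hist_fct_tgt_end]
)
```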
