Skip to content

Commit

Permalink
Native quantile regression for xgb 2.0.0 and above (#2051)
Browse files Browse the repository at this point in the history
* use native quantile regression for xgb 2.0.0 and above

* update changelog

* revert change to objective

---------

Co-authored-by: madtoinou <[email protected]>
  • Loading branch information
dennisbader and madtoinou authored Nov 3, 2023
1 parent 3fb1080 commit 2d0233f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- 🚀🚀 Optimized `historical_forecasts()` for pre-trained `TorchForecastingModel` running up to 20 times faster than before!. [#2013](https://github.com/unit8co/darts/pull/2013) by [Dennis Bader](https://github.com/dennisbader).
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Improvements to Regression Models:
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed a bug when calling optimized `historical_forecasts()` for a `RegressionModel` trained with unequal component-specific lags. [#2040](https://github.com/unit8co/darts/pull/2040) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
17 changes: 15 additions & 2 deletions darts/models/forecasting/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

logger = get_logger(__name__)

# Check whether we are running xgboost >= 2.0.0 for quantile regression
tokens = xgb.__version__.split(".")
xgb_200_or_above = int(tokens[0]) >= 2


def xgb_quantile_loss(labels: np.ndarray, preds: np.ndarray, quantile: float):
"""Custom loss function for XGBoost to compute quantile loss gradient.
Expand Down Expand Up @@ -183,8 +187,12 @@ def encode_year(idx):
if likelihood in {"poisson"}:
self.kwargs["objective"] = f"count:{likelihood}"
elif likelihood == "quantile":
if xgb_200_or_above:
# leverage built-in Quantile Regression
self.kwargs["objective"] = "reg:quantileerror"
self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
self._model_container = self._get_model_container()

self._rng = np.random.default_rng(seed=random_state) # seed for sampling

super().__init__(
Expand Down Expand Up @@ -249,12 +257,17 @@ def fit(
)
]

# TODO: XGBRegressor supports multi quantile reqression which we could leverage in the future
# see https://xgboost.readthedocs.io/en/latest/python/examples/quantile_regression.html
if self.likelihood == "quantile":
# empty model container in case of multiple calls to fit, e.g. when backtesting
self._model_container.clear()
for quantile in self.quantiles:
obj_func = partial(xgb_quantile_loss, quantile=quantile)
self.kwargs["objective"] = obj_func
if xgb_200_or_above:
self.kwargs["quantile_alpha"] = quantile
else:
objective = partial(xgb_quantile_loss, quantile=quantile)
self.kwargs["objective"] = objective
self.model = xgb.XGBRegressor(**self.kwargs)

super().fit(
Expand Down

0 comments on commit 2d0233f

Please sign in to comment.