Merge branch 'master' into feat/tz_aware_dta
dennisbader committed Nov 4, 2023
2 parents 03f00da + 2d0233f commit deb3903
Showing 2 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added callback `darts.utils.callbacks.TFMProgressBar` to customize at which model stages to display the progress bar. [#2020](https://github.com/unit8co/darts/pull/2020) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to documentation:
- Adapted the example notebooks to properly apply data transformers and avoid look-ahead bias. [#2020](https://github.com/unit8co/darts/pull/2020) by [Samriddhi Singh](https://github.com/SimTheGreat).
- Improvements to Regression Models:
- `XGBModel` now leverages XGBoost's native Quantile Regression support that was released in version 2.0.0 for improved probabilistic forecasts. [#2051](https://github.com/unit8co/darts/pull/2051) by [Dennis Bader](https://github.com/dennisbader).
- Other improvements:
- Added support for time index time zone conversion with parameter `tz` before generating/computing holidays and datetime attributes. Support was added to all Time Axis Encoders (standalone encoders and forecasting models' `add_encoders`, the time series generation utility functions `holidays_timeseries()` and `datetime_attribute_timeseries()`, and the `TimeSeries` methods `add_datetime_attribute()` and `add_holidays()`). [#2054](https://github.com/unit8co/darts/pull/2054) by [Dennis Bader](https://github.com/dennisbader).
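
A minimal usage sketch of the new `tz` parameter (the `tz` argument is the addition from #2054; the surrounding calls are existing darts/pandas API, shown here under assumed signatures):

```python
import pandas as pd
from darts.utils.timeseries_generation import (
    datetime_attribute_timeseries,
    holidays_timeseries,
)

# an hourly, time zone-naive index (interpreted as UTC)
idx = pd.date_range(start="2023-01-01", periods=48, freq="H")

# hour-of-day attribute computed after converting the index to CET
hour_cet = datetime_attribute_timeseries(idx, attribute="hour", tz="CET")

# holidays evaluated in the local time zone as well
holidays_cet = holidays_timeseries(idx, country_code="CH", tz="CET")
```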

17 changes: 15 additions & 2 deletions darts/models/forecasting/xgboost.py
@@ -25,6 +25,10 @@

logger = get_logger(__name__)

# Check whether we are running xgboost >= 2.0.0 for quantile regression
tokens = xgb.__version__.split(".")
xgb_200_or_above = int(tokens[0]) >= 2


def xgb_quantile_loss(labels: np.ndarray, preds: np.ndarray, quantile: float):
"""Custom loss function for XGBoost to compute quantile loss gradient.
@@ -184,8 +188,12 @@ def encode_year(idx):
        if likelihood in {"poisson"}:
            self.kwargs["objective"] = f"count:{likelihood}"
        elif likelihood == "quantile":
            if xgb_200_or_above:
                # leverage built-in Quantile Regression
                self.kwargs["objective"] = "reg:quantileerror"
            self.quantiles, self._median_idx = self._prepare_quantiles(quantiles)
            self._model_container = self._get_model_container()

        self._rng = np.random.default_rng(seed=random_state)  # seed for sampling

        super().__init__(
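
For reference, the `reg:quantileerror` objective set above (together with the per-quantile `quantile_alpha` assigned in `fit()` below) corresponds to the following standalone xgboost usage (a sketch assuming xgboost >= 2.0.0; data and hyperparameters are made up):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=200)

# native quantile regression: one regressor per target quantile
model = xgb.XGBRegressor(objective="reg:quantileerror", quantile_alpha=0.9)
model.fit(X, y)
q90 = model.predict(X)  # estimates of the conditional 0.9-quantile
```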
@@ -250,12 +258,17 @@ def fit(
            )
        ]

        # TODO: XGBRegressor supports multi-quantile regression, which we could leverage in the future
        # see https://xgboost.readthedocs.io/en/latest/python/examples/quantile_regression.html
        if self.likelihood == "quantile":
            # empty model container in case of multiple calls to fit, e.g. when backtesting
            self._model_container.clear()
            for quantile in self.quantiles:
                if xgb_200_or_above:
                    self.kwargs["quantile_alpha"] = quantile
                else:
                    objective = partial(xgb_quantile_loss, quantile=quantile)
                    self.kwargs["objective"] = objective
                self.model = xgb.XGBRegressor(**self.kwargs)

                super().fit(
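From a user's perspective, the two code paths above are selected transparently; a minimal end-to-end sketch with the public `XGBModel` API (series and hyperparameters are illustrative):

```python
import numpy as np
from darts import TimeSeries
from darts.models import XGBModel

series = TimeSeries.from_values(np.sin(np.linspace(0, 20, 200)))

# with xgboost >= 2.0.0, each quantile model uses the native
# "reg:quantileerror" objective; older versions fall back to the
# custom xgb_quantile_loss objective shown in the diff above
model = XGBModel(lags=24, likelihood="quantile", quantiles=[0.05, 0.5, 0.95])
model.fit(series)
pred = model.predict(n=12, num_samples=200)  # sampled probabilistic forecast
```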
